source: trunk/tests/test_classic_connection.py @ 730

Last change on this file since 730 was 730, checked in by cito, 4 years ago

Use query parameters instead of inline values

The single row methods of the DB wrapper class created queries with inline values
instead of passing them separately as parameters, even though our query method
does have this capability. Using query parameters also spares us a lot of quoting
and escaping that is necessary when passing values inline.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 59.0 KB
RevLine 
[539]1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
[553]15    import unittest2 as unittest  # for Python < 2.7
[539]16except ImportError:
17    import unittest
[547]18import threading
19import time
[607]20import os
[539]21
22import pg  # the module under test
23
24from decimal import Decimal
25
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# These tests should be run with various PostgreSQL versions and databases
# created with different encodings and locales.  Particularly, make sure the
# tests are running against databases created with both SQL_ASCII and UTF8.
dbname = 'unittest'
dbhost = None
dbport = 5432

# LOCAL_PyGreSQL may override the connection settings above.  Try it first
# as a sibling module of this tests package, then as a top-level module.
# A relative import attempted outside of a package fails differently per
# Python version: ValueError on Python 2, SystemError on Python 3.3 to 3.5
# and ImportError on Python 3.6 and later, so all three are caught here.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError, SystemError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass
[539]42
# Python 2/3 compatibility aliases: Python 3 has neither a separate long
# type nor a separate unicode type, so map them onto int and str.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# True when str is a text type distinct from bytes (i.e. on Python 3).
unicode_strings = str is not bytes

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
61
62
def connect():
    """Create a basic pg connection to the test database.

    The connection is opened with the module-level dbname, dbhost and
    dbport settings, and client_min_messages is set to warning for it.
    If that setup query fails, the freshly opened connection is closed
    before the error is re-raised, so that no connection is leaked.
    """
    connection = pg.connect(dbname, dbhost, dbport)
    try:
        connection.query("set client_min_messages=warning")
    except Exception:
        connection.close()
        raise
    return connection
68
69
class TestCanConnect(unittest.TestCase):
    """Verify that a plain connection to PostgreSQL can be opened and closed."""

    def testCanConnect(self):
        # Opening the connection must not raise any pg.Error ...
        try:
            conn = connect()
        except pg.Error as exc:
            self.fail('Cannot connect to database %s:\n%s' % (dbname, exc))
        # ... and neither must closing it again.
        try:
            conn.close()
        except pg.Error:
            self.fail('Cannot close the database connection')
82
83
class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods."""

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        # Tolerate a connection that has already been closed; closing a
        # closed connection raises pg.InternalError.
        try:
            self.connection.close()
        except pg.InternalError:
            pass

    def is_method(self, attribute):
        """Check if given attribute on the connection is a method."""
        # Report 'host' as a plain attribute without reading it where
        # accessing it may crash libpq (see do_not_ask_for_host).
        if do_not_ask_for_host and attribute == 'host':
            return False
        return callable(getattr(self.connection, attribute))

    def testClassName(self):
        self.assertEqual(self.connection.__class__.__name__, 'Connection')

    def testModuleName(self):
        self.assertEqual(self.connection.__class__.__module__, 'pg')

    def testStr(self):
        r = str(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testRepr(self):
        r = repr(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testAllConnectAttributes(self):
        attributes = '''db error host options port
            protocol_version server_version status user'''.split()
        connection_attributes = [a for a in dir(self.connection)
            if not a.startswith('__') and not self.is_method(a)]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        methods = '''cancel close endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter putline query reset
            set_notice_receiver source transaction'''.split()
        connection_methods = [a for a in dir(self.connection)
            if not a.startswith('__') and self.is_method(a)]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        # No error message is expected, apart from messages mentioning krb5_.
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, dbhost or def_host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        # Parameters may be passed either as a tuple or as a list.
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testMethodEndcopy(self):
        # endcopy() is called here without any preceding copy command;
        # IOError is tolerated, any other exception fails the test.
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # Reconnect so that tearDown() has an open connection to close.
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query
        # check that client encoding gets reset
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.ProgrammingError as error:
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        # Wait until the query thread is running.  Note that is_alive() only
        # shows that the thread has started, not that the query has reached
        # the server; the 0.1 second sleep is what gives it time to get there.
        while True:
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)  # the cancelled query must have raised

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)
259
class TestSimpleQueries(unittest.TestCase):
    """Test simple queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testClassName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__name__, 'Query')

    def testModuleName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__module__, 'pg')

    def testStr(self):
        q = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        r = self.c.query(q)
        self.assertEqual(str(r),
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)')

    def testRepr(self):
        r = repr(self.c.query("select 1"))
        self.assertTrue(r.startswith('<pg.Query object'), r)

    def testSelect0(self):
        q = "select 0"
        self.c.query(q)

    def testSelect0Semicolon(self):
        q = "select 0;"
        self.c.query(q)

    def testSelectDotSemicolon(self):
        q = "select .;"
        self.assertRaises(pg.ProgrammingError, self.c.query, q)

    def testGetresult(self):
        q = "select 0"
        result = [(0,)]
        r = self.c.query(q).getresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, tuple)
        self.assertIsInstance(v[0], int)
        self.assertEqual(r, result)

    def testGetresultLong(self):
        q = "select 9876543210"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testGetresultDecimal(self):
        q = "select 98765432109876543210"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testGetresultString(self):
        result = 'Hello, world!'
        q = "select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresult(self):
        q = "select 0 as alias0"
        result = [{'alias0': 0}]
        r = self.c.query(q).dictresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, dict)
        self.assertIsInstance(v['alias0'], int)
        self.assertEqual(r, result)

    def testDictresultLong(self):
        q = "select 9876543210 as longjohnsilver"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testDictresultDecimal(self):
        q = "select 98765432109876543210 as longjohnsilver"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testDictresultString(self):
        result = 'Hello, world!'
        q = "select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testNamedresult(self):
        q = "select 0 as alias0"
        result = [(0,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('alias0',))
        self.assertEqual(v.alias0, 0)

    def testGet3Cols(self):
        q = "select 1,2,3"
        result = [(1, 2, 3)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [dict(a=1, b=2, c=3)]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [(1, 2, 3)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('a', 'b', 'c'))
        self.assertEqual(v.b, 2)

    def testGet3Rows(self):
        q = "select 3 union select 1 union select 2 order by 1"
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        for v in r:
            self.assertEqual(v._fields, ('alias3',))

    def testDictresultNames(self):
        # Unquoted aliases are folded to lower case, quoted ones are kept.
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [{'mixedcasealias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        result = [{'MixedCaseAlias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testNamedresultNames(self):
        # Unquoted aliases are folded to lower case, quoted ones are kept.
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [('MixedCase',)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('mixedcasealias',))
        self.assertEqual(v.mixedcasealias, 'MixedCase')
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('MixedCaseAlias',))
        self.assertEqual(v.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = 100
        num_rows = 100
        q = "select " + ','.join(map(str, range(num_cols)))
        q = ' union all '.join((q,) * num_rows)
        r = self.c.query(q).getresult()
        result = [tuple(range(num_cols))] * num_rows
        self.assertEqual(r, result)

    def testListfields(self):
        q = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        r = self.c.query(q).listfields()
        result = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(r, result)

    def testFieldname(self):
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldname(2)
        self.assertEqual(r, 'x')
        r = self.c.query(q).fieldname(3)
        self.assertEqual(r, 'y')

    def testFieldnum(self):
        q = "select 1 as x"
        # unknown field names raise ValueError
        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        r = self.c.query(q).fieldnum('y')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3)

    def testNtuples(self):
        q = "select 1 where false"
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 6)

    def testQuery(self):
        query = self.c.query
        query("drop table if exists test_table")
        # NOTE(review): WITH OIDS was removed in PostgreSQL 12, so this
        # statement fails on newer servers.
        q = "create table test_table (n integer) with oids"
        r = query(q)
        self.assertIsNone(r)
        # Drop the table again even if one of the checks below fails,
        # so that no stale test table is left in the database.
        try:
            q = "insert into test_table values (1)"
            r = query(q)
            self.assertIsInstance(r, int)
            # inserting a single row returns the oid of the new row
            q = "insert into test_table select 2"
            r = query(q)
            self.assertIsInstance(r, int)
            oid = r
            q = "select oid from test_table where n=2"
            r = query(q).getresult()
            self.assertEqual(len(r), 1)
            r = r[0]
            self.assertEqual(len(r), 1)
            r = r[0]
            self.assertIsInstance(r, int)
            self.assertEqual(r, oid)
            # multi-row inserts, updates and deletes return the number
            # of affected rows as a string
            q = "insert into test_table select 3 union select 4 union select 5"
            r = query(q)
            self.assertIsInstance(r, str)
            self.assertEqual(r, '3')
            q = "update test_table set n=4 where n<5"
            r = query(q)
            self.assertIsInstance(r, str)
            self.assertEqual(r, '4')
            q = "delete from test_table"
            r = query(q)
            self.assertIsInstance(r, str)
            self.assertEqual(r, '5')
        finally:
            query("drop table test_table")
540
541
class TestUnicodeQueries(unittest.TestCase):
    """Test unicode strings as queries via a basic pg connection.

    Except for the ASCII tests, every test runs its query twice: first
    passed as a unicode string, then passed as bytes encoded in the
    client encoding that was set for the test.
    """

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')

    def tearDown(self):
        self.c.close()

    def testGetresultAscii(self):
        result = u'Hello, world!'
        q = u"select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultAscii(self):
        result = u'Hello, world!'
        q = u"select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('utf8')
        # pass the query as unicode
        try:
            v = self.c.query(q).getresult()[0][0]
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        # pass the query as bytes
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('utf8')
        try:
            v = self.c.query(q).dictresult()[0]['greeting']
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin1(self):
        # Formerly also named testDictresultLatin1, which meant it was
        # shadowed by the dictresult variant below and never ran.
        try:
            self.c.query('set client_encoding=latin1')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except pg.ProgrammingError:
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except pg.ProgrammingError:
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin9")
        # uses characters that are in latin9 but not in latin1 (e.g. œ, š, ž, €)
        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin9")
        # uses characters that are in latin9 but not in latin1 (e.g. œ, š, ž, €)
        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
        q = u"select '%s' as menu" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
701
702
[539]703class TestParamQueries(unittest.TestCase):
[641]704    """Test queries with parameters via a basic pg connection."""
[539]705
706    def setUp(self):
707        self.c = connect()
[613]708        self.c.query('set client_encoding=utf8')
[539]709
710    def tearDown(self):
711        self.c.close()
712
713    def testQueryWithNoneParam(self):
[655]714        self.assertRaises(TypeError, self.c.query, "select $1", None)
715        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
[539]716        self.assertEqual(self.c.query("select $1::integer", (None,)
717            ).getresult(), [(None,)])
718        self.assertEqual(self.c.query("select $1::text", [None]
719            ).getresult(), [(None,)])
[655]720        self.assertEqual(self.c.query("select $1::text", [[None]]
721            ).getresult(), [(None,)])
[539]722
[634]723    def testQueryWithBoolParams(self, use_bool=None):
[539]724        query = self.c.query
[634]725        if use_bool is not None:
726            use_bool_default = pg.get_bool()
727            pg.set_bool(use_bool)
728        try:
729            v_false, v_true = (False, True) if use_bool else 'ft'
730            r_false, r_true = [(v_false,)], [(v_true,)]
731            self.assertEqual(query("select false").getresult(), r_false)
732            self.assertEqual(query("select true").getresult(), r_true)
733            q = "select $1::bool"
734            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
735            self.assertEqual(query(q, ('f',)).getresult(), r_false)
736            self.assertEqual(query(q, ('t',)).getresult(), r_true)
737            self.assertEqual(query(q, ('false',)).getresult(), r_false)
738            self.assertEqual(query(q, ('true',)).getresult(), r_true)
739            self.assertEqual(query(q, ('n',)).getresult(), r_false)
740            self.assertEqual(query(q, ('y',)).getresult(), r_true)
741            self.assertEqual(query(q, (0,)).getresult(), r_false)
742            self.assertEqual(query(q, (1,)).getresult(), r_true)
743            self.assertEqual(query(q, (False,)).getresult(), r_false)
744            self.assertEqual(query(q, (True,)).getresult(), r_true)
745        finally:
746            if use_bool is not None:
747                pg.set_bool(use_bool_default)
[539]748
[634]749    def testQueryWithBoolParamsAndUseBool(self):
750        self.testQueryWithBoolParams(use_bool=True)
751
[539]752    def testQueryWithIntParams(self):
753        query = self.c.query
754        self.assertEqual(query("select 1+1").getresult(), [(2,)])
755        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
756        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
757        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
758        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
759        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
760            [(Decimal('2'),)])
761        self.assertEqual(query("select 1, $1::integer", (2,)
762            ).getresult(), [(1, 2)])
[690]763        self.assertEqual(query("select 1 union select $1::integer", (2,)
[539]764            ).getresult(), [(1,), (2,)])
765        self.assertEqual(query("select $1::integer+$2", (1, 2)
766            ).getresult(), [(3,)])
767        self.assertEqual(query("select $1::integer+$2", [1, 2]
768            ).getresult(), [(3,)])
[585]769        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
[539]770            ).getresult(), [(15,)])
771
772    def testQueryWithStrParams(self):
773        query = self.c.query
774        self.assertEqual(query("select $1||', world!'", ('Hello',)
775            ).getresult(), [('Hello, world!',)])
776        self.assertEqual(query("select $1||', world!'", ['Hello']
777            ).getresult(), [('Hello, world!',)])
778        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
779            ).getresult(), [('Hello, world!',)])
780        self.assertEqual(query("select $1::text", ('Hello, world!',)
781            ).getresult(), [('Hello, world!',)])
782        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
783            ).getresult(), [('Hello', 'world')])
784        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
785            ).getresult(), [('Hello', 'world')])
786        self.assertEqual(query("select $1::text union select $2::text",
787            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
788        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
789            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
790
791    def testQueryWithUnicodeParams(self):
792        query = self.c.query
[613]793        try:
794            query('set client_encoding=utf8')
795            query("select 'wörld'").getresult()[0][0] == 'wörld'
796        except pg.ProgrammingError:
797            self.skipTest("database does not support utf8")
[539]798        self.assertEqual(query("select $1||', '||$2||'!'",
[585]799            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
[613]800
801    def testQueryWithUnicodeParamsLatin1(self):
802        query = self.c.query
803        try:
804            query('set client_encoding=latin1')
805            query("select 'wörld'").getresult()[0][0] == 'wörld'
806        except pg.ProgrammingError:
807            self.skipTest("database does not support latin1")
[585]808        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
809        if unicode_strings:
810            self.assertEqual(r, [('Hello, wörld!',)])
811        else:
812            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
[539]813        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
[585]814            ('Hello', u'ЌОр'))
[613]815        query('set client_encoding=iso_8859_1')
816        r = query("select $1||', '||$2||'!'",
817            ('Hello', u'wörld')).getresult()
[585]818        if unicode_strings:
819            self.assertEqual(r, [('Hello, wörld!',)])
820        else:
821            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
[539]822        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
[585]823            ('Hello', u'ЌОр'))
[613]824        query('set client_encoding=sql_ascii')
[539]825        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
[585]826            ('Hello', u'wörld'))
[613]827
828    def testQueryWithUnicodeParamsCyrillic(self):
829        query = self.c.query
830        try:
831            query('set client_encoding=iso_8859_5')
832            query("select 'ЌОр'").getresult()[0][0] == 'ЌОр'
833        except pg.ProgrammingError:
834            self.skipTest("database does not support cyrillic")
835        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
836            ('Hello', u'wörld'))
837        r = query("select $1||', '||$2||'!'",
838            ('Hello', u'ЌОр')).getresult()
[585]839        if unicode_strings:
840            self.assertEqual(r, [('Hello, ЌОр!',)])
841        else:
842            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
[613]843        query('set client_encoding=sql_ascii')
[539]844        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
[613]845            ('Hello', u'ЌОр!'))
[539]846
847    def testQueryWithMixedParams(self):
848        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
849            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
850        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
851            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
852
853    def testQueryWithDuplicateParams(self):
854        self.assertRaises(pg.ProgrammingError,
855            self.c.query, "select $1+$1", (1,))
856        self.assertRaises(pg.ProgrammingError,
857            self.c.query, "select $1+$1", (1, 2))
858
859    def testQueryWithZeroParams(self):
860        self.assertEqual(self.c.query("select 1+1", []
861            ).getresult(), [(2,)])
862
863    def testQueryWithGarbage(self):
864        garbage = r"'\{}+()-#[]oo324"
865        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
866            ).dictresult(), [{'garbage': garbage}])
867
868
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    @classmethod
    def setUpClass(cls):
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test ("
                "i2 smallint, i4 integer, i8 bigint, b boolean, dt date,"
                " ti time,"
                "d numeric, f4 real, f8 double precision, m money,"
                "c char(1), v4 varchar(4), c4 char(4), t text)")
            # Check whether the test database uses SQL_ASCII - this means
            # that it does not consider encoding when calculating lengths.
            c.query("set client_encoding=utf8")
            cls.has_encoding = c.query(
                "select length('À') - length('a')").getresult()[0][0] == 0
        finally:
            c.close()

    @classmethod
    def tearDownClass(cls):
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        try:
            self.c.query("truncate table test")
        finally:
            self.c.close()

    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def db_len(cls, s, encoding):
        """Return the length of s as the database would count it.

        Characters are counted if the database considers the encoding,
        otherwise the bytes in the given encoding are counted.
        """
        if cls.has_encoding:
            s = s if isinstance(s, unicode) else s.decode(encoding)
        else:
            s = s.encode(encoding) if isinstance(s, unicode) else s
        return len(s)

    def get_back(self, encoding='utf-8'):
        """Convert boolean and decimal values back."""
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], str)
                row[3] = {'f': False, 't': True}.get(row[3])
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision
                self.assertIsInstance(row[8], float)
                row[8] = float(row[8])
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(self.db_len(row[10], encoding), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(self.db_len(row[11], encoding), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(self.db_len(row[12], encoding), 4)
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            row = tuple(row)
            data.append(row)
        return data

    def _unicode_row(self, text):
        """Return a test row of unicode values with text in the last column."""
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        return (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0', c, u'bÀd', u'bÀd', text)

    @staticmethod
    def _encode_row(row, encoding):
        """Return row with all unicode values encoded with encoding."""
        return tuple(s.encode(encoding) if isinstance(s, unicode) else s
            for s in row)

    def testInserttable1Row(self):
        data = self.data[2:3]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        data = self.data
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMultipleRows(self):
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        data = [(None,) * 14] * 100
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableByteValues(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [self._encode_row(row_unicode, 'utf-8')] * 2
        self.c.inserttable('test', data)
        if unicode_strings:
            data = [row_unicode] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeUtf8(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'utf-8')] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeLatin1(self):
        try:
            self.c.query("set client_encoding=latin1")
            self.c.query("select '¥'")
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode € sign with latin1 encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
        # the yen sign is a single latin1 character, so the replaced value
        # still fits into the char(1) column
        row_unicode = tuple(s.replace(u'€', u'¥')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'latin1')] * 2
        self.assertEqual(self.get_back('latin1'), data)

    def testInserttableUnicodeLatin9(self):
        try:
            self.c.query("set client_encoding=latin9")
            self.c.query("select '€'")
        except pg.ProgrammingError:
            self.skipTest("database does not support latin9")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'latin9')] * 2
        self.assertEqual(self.get_back('latin9'), data)

    def testInserttableNoEncoding(self):
        self.c.query("set client_encoding=sql_ascii")
        data = [self._unicode_row(u"for kÀse and pont-l'évêque pay in €")]
        # cannot encode non-ascii unicode without a specific encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
1098
1099
class TestDirectSocketAccess(unittest.TestCase):
    """Test the copy command using putline, getline and endcopy."""

    @classmethod
    def setUpClass(cls):
        conn = connect()
        conn.query("drop table if exists test cascade")
        conn.query("create table test (i int, v varchar(16))")
        conn.close()

    @classmethod
    def tearDownClass(cls):
        conn = connect()
        conn.query("drop table test cascade")
        conn.close()

    def setUp(self):
        self.c = connect()
        self.c.query("set client_encoding=utf8")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    def testPutline(self):
        """Copy rows into the table line by line."""
        rows = list(enumerate("apple pear plum cherry banana".split()))
        self.c.query("copy test from stdin")
        try:
            for row in rows:
                self.c.putline("%d\t%s\n" % row)
            self.c.putline("\\.\n")
        finally:
            self.c.endcopy()
        stored = self.c.query("select * from test").getresult()
        self.assertEqual(stored, rows)

    def testPutlineBytesAndUnicode(self):
        """Check that putline accepts both bytes and text lines."""
        self.c.query("copy test from stdin")
        try:
            for line in (u"47\tkÀse\n".encode('utf8'),
                    "35\twÃŒrstel\n", b"\\.\n"):
                self.c.putline(line)
        finally:
            self.c.endcopy()
        stored = self.c.query("select * from test").getresult()
        self.assertEqual(stored, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        """Read rows back line by line, then the terminator, then None."""
        rows = list(enumerate("apple banana pear plum strawberry".split()))
        self.c.inserttable('test', rows)
        self.c.query("copy test to stdout")
        try:
            for row in rows:
                self.assertEqual(self.c.getline(), '%d\t%s' % row)
            self.assertEqual(self.c.getline(), '\\.')
            self.assertIsNone(self.c.getline())
        finally:
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testGetlineBytesAndUnicode(self):
        """Check that getline always returns native strings."""
        rows = [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')]
        self.c.inserttable('test', rows)
        self.c.query("copy test to stdout")
        try:
            for expected in ('54\tkÀse', '73\twÃŒrstel'):
                line = self.c.getline()
                self.assertIsInstance(line, str)
                self.assertEqual(line, expected)
            self.assertEqual(self.c.getline(), '\\.')
            self.assertIsNone(self.c.getline())
        finally:
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        """Check that wrong arguments raise TypeError."""
        conn = self.c
        self.assertRaises(TypeError, conn.putline)
        for method in (conn.getline, conn.endcopy):
            self.assertRaises(TypeError, method, 'invalid')
1198
1199
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def _assertNotification(self, r, payload):
        """Check that r is a 3-tuple notification on test_notify with payload."""
        self.assertIsInstance(r, tuple)
        self.assertEqual(len(r), 3)
        self.assertIsInstance(r[0], str)
        self.assertIsInstance(r[1], int)
        self.assertIsInstance(r[2], str)
        self.assertEqual(r[0], 'test_notify')
        self.assertEqual(r[2], payload)

    def testGetNotify(self):
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            self._assertNotification(getnotify(), '')
            self.assertIsNone(getnotify())
            query("notify test_notify, 'test_payload'")
            self._assertNotification(getnotify(), 'test_payload')
            self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, None)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))

    def testSetAndGetNoticeReceiver(self):
        def r(notice):
            return None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)

    def testNoticeReceiver(self):
        # "or replace" lets the test recover from a function left over
        # by an earlier run that was interrupted before the drop
        self.c.query('''create or replace function bilbo_notice()
            returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        try:
            received = {}

            def notice_receiver(notice):
                # collect all public attributes of the notice object
                for attr in dir(notice):
                    if attr.startswith('__'):
                        continue
                    value = getattr(notice, attr)
                    if isinstance(value, str):
                        # normalize a German server locale
                        value = value.replace('WARNUNG', 'WARNING')
                    received[attr] = value

            self.c.set_notice_receiver(notice_receiver)
            self.c.query('''select bilbo_notice()''')
            self.assertEqual(received, dict(
                pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
                severity='WARNING', primary='Bilbo was here!',
                detail=None, hint=None))
        finally:
            self.c.query('''drop function bilbo_notice();''')
1278
1279
class TestConfigFunctions(unittest.TestCase):
    """Test the functions for changing default settings.

    To test the effect of most of these functions, we need a database
    connection.  That's why they are covered in this test module.

    """

    def setUp(self):
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        self.c.close()

    def testGetDecimalPoint(self):
        point = pg.get_decimal_point()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal_point, point)
        self.assertIsInstance(point, str)
        self.assertEqual(point, '.')  # the default setting
        pg.set_decimal_point(',')
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertEqual(r, ',')
        pg.set_decimal_point("'")
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertEqual(r, "'")
        pg.set_decimal_point('')
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsNone(r)
        pg.set_decimal_point(None)
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsNone(r)

    def testSetDecimalPoint(self):
        d = pg.Decimal
        point = pg.get_decimal_point()
        self.assertRaises(TypeError, pg.set_decimal_point)
        # error if decimal point is not a string
        self.assertRaises(TypeError, pg.set_decimal_point, 0)
        # error if more than one decimal point passed
        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
        # error if decimal point is not a punctuation character
        self.assertRaises(TypeError, pg.set_decimal_point, '0')
        query = self.c.query
        # check that money values are interpreted as decimal values
        # only if decimal_point is set, and that the result is correct
        # only if it is set suitable for the current lc_monetary setting
        select_money = "select '34.25'::money"
        proper_money = d('34.25')
        bad_money = d('3425')
        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
        # first try with English localization (using the point)
        for lc in en_locales:
            try:
                query("set lc_monetary='%s'" % lc)
            except pg.ProgrammingError:
                pass
            else:
                break
        else:
            self.skipTest("cannot set English money locale")
        try:
            r = query(select_money)
        except pg.ProgrammingError:
            # this can happen if the currency signs cannot be
            # converted using the encoding of the test database
            self.skipTest("database does not support English money")
        pg.set_decimal_point(None)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertIn(r, en_money)
        r = query(select_money)
        pg.set_decimal_point('')
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertIn(r, en_money)
        r = query(select_money)
        pg.set_decimal_point('.')
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, proper_money)
        r = query(select_money)
        pg.set_decimal_point(',')
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)
        r = query(select_money)
        pg.set_decimal_point("'")
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)
        # then try with German localization (using the comma)
        for lc in de_locales:
            try:
                query("set lc_monetary='%s'" % lc)
            except pg.ProgrammingError:
                pass
            else:
                break
        else:
            self.skipTest("cannot set German money locale")
        select_money = select_money.replace('.', ',')
        try:
            r = query(select_money)
        except pg.ProgrammingError:
            # this can happen if the currency signs cannot be
            # converted using the encoding of the test database
            self.skipTest("database does not support German money")
        pg.set_decimal_point(None)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertIn(r, de_money)
        r = query(select_money)
        pg.set_decimal_point('')
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertIn(r, de_money)
        r = query(select_money)
        pg.set_decimal_point(',')
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, proper_money)
        r = query(select_money)
        pg.set_decimal_point('.')
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)
        r = query(select_money)
        pg.set_decimal_point("'")
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)

    def testGetDecimal(self):
        decimal_class = pg.get_decimal()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
        self.assertIs(decimal_class, pg.Decimal)  # the default setting
        pg.set_decimal(int)
        try:
            r = pg.get_decimal()
        finally:
            pg.set_decimal(decimal_class)
        self.assertIs(r, int)
        r = pg.get_decimal()
        self.assertIs(r, decimal_class)

    def testSetDecimal(self):
        decimal_class = pg.get_decimal()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_decimal)
        query = self.c.query
        try:
            r = query("select 3425::numeric")
        except pg.ProgrammingError:
            self.skipTest('database does not support numeric')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, decimal_class)
        self.assertEqual(r, decimal_class('3425'))
        r = query("select 3425::numeric")
        pg.set_decimal(int)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal(decimal_class)
        self.assertNotIsInstance(r, decimal_class)
        self.assertIsInstance(r, int)
        self.assertEqual(r, int(3425))

    def testGetBool(self):
        use_bool = pg.get_bool()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_bool, use_bool)
        self.assertIsInstance(use_bool, bool)
        self.assertIs(use_bool, False)  # the default setting
        pg.set_bool(True)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        pg.set_bool(False)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, False)
        pg.set_bool(1)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        pg.set_bool(0)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, False)

    def testSetBool(self):
        use_bool = pg.get_bool()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_bool)
        query = self.c.query
        try:
            r = query("select true::bool")
        except pg.ProgrammingError:
            self.skipTest('database does not support bool')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, str)
        self.assertEqual(r, 't')
        r = query("select true::bool")
        pg.set_bool(True)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        r = query("select true::bool")
        pg.set_bool(False)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, str)
        # compare strings by value; identity would rely on interning
        self.assertEqual(r, 't')

    def testGetNamedresult(self):
        namedresult = pg.get_namedresult()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
        self.assertIs(namedresult, pg._namedresult)  # the default setting

    def testSetNamedresult(self):
        namedresult = pg.get_namedresult()
        self.assertTrue(callable(namedresult))

        query = self.c.query

        r = query("select 1 as x, 2 as y").namedresult()[0]
        self.assertIsInstance(r, tuple)
        self.assertEqual(r, (1, 2))
        self.assertIsNot(type(r), tuple)
        self.assertEqual(r._fields, ('x', 'y'))
        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
        self.assertEqual(r.__class__.__name__, 'Row')

        def listresult(q):
            return [list(row) for row in q.getresult()]

        pg.set_namedresult(listresult)
        try:
            r = pg.get_namedresult()
            self.assertIs(r, listresult)
            r = query("select 1 as x, 2 as y").namedresult()[0]
            self.assertIsInstance(r, list)
            self.assertEqual(r, [1, 2])
            self.assertIsNot(type(r), tuple)
            self.assertFalse(hasattr(r, '_fields'))
            self.assertNotEqual(r.__class__.__name__, 'Row')
        finally:
            pg.set_namedresult(namedresult)

        r = pg.get_namedresult()
        self.assertIs(r, namedresult)
[546]1598
[635]1599
class TestStandaloneEscapeFunctions(unittest.TestCase):
    """Test pg escape functions.

    The libpq interface memorizes some parameters of the last opened
    connection that influence the result of these functions.  Therefore
    we need to open a connection with fixed parameters prior to testing
    in order to ensure that the tests always run under the same conditions.
    That's why these tests are included in this test module.

    """

    @classmethod
    def setUpClass(cls):
        """Open a connection with the fixed settings these tests rely on.

        The connection is closed explicitly instead of being left to the
        garbage collector.  Only the parameters libpq remembers from the
        last opened connection are needed afterwards.
        """
        db = connect()
        try:
            query = db.query
            query('set client_encoding=sql_ascii')
            query('set standard_conforming_strings=off')
            query('set bytea_output=escape')
        finally:
            db.close()

    def testEscapeString(self):
        """Test escape_string() with bytes and unicode input."""
        f = pg.escape_string
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # \u00e4 is a-umlaut; written as an escape so the literal cannot
        # be mangled by a wrong source-file encoding
        r = f(u"das is' k\u00e4se".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, u"das is'' k\u00e4se".encode('utf-8'))
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        # standard_conforming_strings is off, so backslashes get doubled
        r = f(r"It's bad to have a \ inside.")
        self.assertEqual(r, r"It''s bad to have a \\ inside.")

    def testEscapeBytea(self):
        """Test escape_bytea() with bytes and unicode input."""
        f = pg.escape_bytea
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # UTF-8 of a-umlaut is 0xC3 0xA4, i.e. octal 303 and 244 below
        r = f(u"das is' k\u00e4se".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(b'O\x00ps\xff!')
        self.assertEqual(r, b'O\\\\000ps\\\\377!')
1652
if __name__ == '__main__':
    # Run every test case of this module when executed as a script.
    unittest.main(module='__main__')
Note: See TracBrowser for help on using the repository browser.