source: trunk/tests/test_classic_connection.py @ 894

Last change on this file since 894 was 894, checked in by cito, 3 years ago

Cache the namedtuple classes used for query result rows

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 68.2 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17import threading
18import time
19import os
20
21import pg  # the module under test
22
23from decimal import Decimal
24
# Connection parameters of the database the tests run against.  A module
# named LOCAL_PyGreSQL, if it can be imported, may override these defaults,
# so they have to be bound first.  Run the suite against several PostgreSQL
# versions and against databases created with different encodings and
# locales -- in particular against both an SQL_ASCII and a UTF8 database.
dbname, dbhost, dbport = 'unittest', None, 5432
33
34try:
35    from .LOCAL_PyGreSQL import *
36except (ImportError, ValueError):
37    try:
38        from LOCAL_PyGreSQL import *
39    except ImportError:
40        pass
41
# Compatibility names so that the same tests run on Python 2 and Python 3.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# True if the native str type holds unicode text (Python 3),
# False if str is the same type as bytes (Python 2).
unicode_strings = str is not bytes

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost(), so the tests must
# not read the host attribute of a connection on that platform.
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
60
61
def connect():
    """Open and return a new pg connection to the test database.

    The connection is opened with the module-level dbname, dbhost and
    dbport settings, and client_min_messages is set to 'warning' on it
    before it is handed back.
    """
    params = (dbname, dbhost, dbport)
    conn = pg.connect(*params)
    conn.query("set client_min_messages=warning")
    return conn
67
68
class TestCanConnect(unittest.TestCase):
    """Check that a connection to the test database can be opened and closed."""

    def testCanConnect(self):
        try:
            conn = connect()
        except pg.Error as err:
            self.fail('Cannot connect to database %s:\n%s' % (dbname, err))
        else:
            # Only reached when the connection was actually established.
            try:
                conn.close()
            except pg.Error:
                self.fail('Cannot close the database connection')
81
82
class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods.

    Each test gets a fresh connection from connect() in setUp; tearDown
    closes it again.
    """

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        # Closing an already closed connection raises pg.InternalError
        # (see testMethodClose); ignore that so teardown never hides the
        # real outcome of a test.
        try:
            self.connection.close()
        except pg.InternalError:
            pass

    def is_method(self, attribute):
        """Check if given attribute on the connection is a method."""
        # Reading 'host' can crash libpq on Windows, so it is never touched
        # there and simply reported as a non-method.
        if do_not_ask_for_host and attribute == 'host':
            return False
        return callable(getattr(self.connection, attribute))

    def testClassName(self):
        self.assertEqual(self.connection.__class__.__name__, 'Connection')

    def testModuleName(self):
        self.assertEqual(self.connection.__class__.__module__, 'pg')

    def testStr(self):
        r = str(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testRepr(self):
        r = repr(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testAllConnectAttributes(self):
        attributes = '''db error host options port
            protocol_version server_version status user'''.split()
        connection_attributes = [a for a in dir(self.connection)
            if not a.startswith('__') and not self.is_method(a)]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        methods = '''cancel close date_format endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_cast_hook get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter putline query reset
            set_cast_hook set_notice_receiver source transaction'''.split()
        connection_methods = [a for a in dir(self.connection)
            if not a.startswith('__') and self.is_method(a)]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, dbhost or def_host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        # libpq encodes the version as an integer (e.g. 90405 for 9.4.5,
        # 100001 for 10.1), so servers from PostgreSQL 10 on have six
        # digits; the old upper bound of 100000 rejected all of them.
        self.assertTrue(70400 <= server_version < 1000000)

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testMethodEndcopy(self):
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # reopen so that tearDown has a live connection to close
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query
        # check that client encoding gets reset
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.DatabaseError as error:
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        # Wait until the worker thread is alive (giving up after 5 seconds).
        # NOTE(review): this only shows the thread started, not that the
        # query has already reached the server.
        while 1:
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)

    def testMethodTransaction(self):
        transaction = self.connection.transaction
        self.assertRaises(TypeError, transaction, None)
        self.assertEqual(transaction(), pg.TRANS_IDLE)
        self.connection.query('begin')
        self.assertEqual(transaction(), pg.TRANS_INTRANS)
        self.connection.query('rollback')
        self.assertEqual(transaction(), pg.TRANS_IDLE)

    def testMethodParameter(self):
        """Compare parameter() with the output of the SHOW command."""
        parameter = self.connection.parameter
        query = self.connection.query
        self.assertRaises(TypeError, parameter)
        r = parameter('this server setting does not exist')
        self.assertIsNone(r)
        # The version string is compared verbatim: uppercasing only the
        # SHOW output would break for versions such as '9.5beta1'.
        s = query('show server_version').getresult()[0][0]
        self.assertIsNotNone(s)
        r = parameter('server_version')
        self.assertEqual(r, s)
        s = query('show server_encoding').getresult()[0][0].upper()
        self.assertIsNotNone(s)
        r = parameter('server_encoding')
        self.assertEqual(r, s)
        s = query('show client_encoding').getresult()[0][0].upper()
        self.assertIsNotNone(s)
        r = parameter('client_encoding')
        self.assertEqual(r, s)
289
290
class TestSimpleQueries(unittest.TestCase):
    """Test simple queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # Run registered cleanups (e.g. dropping test tables) while the
        # connection they use is still open, then close it.
        self.doCleanups()
        self.c.close()

    def testClassName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__name__, 'Query')

    def testModuleName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__module__, 'pg')

    def testStr(self):
        q = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        r = self.c.query(q)
        self.assertEqual(str(r),
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)')

    def testRepr(self):
        r = repr(self.c.query("select 1"))
        self.assertTrue(r.startswith('<pg.Query object'), r)

    def testSelect0(self):
        q = "select 0"
        self.c.query(q)

    def testSelect0Semicolon(self):
        q = "select 0;"
        self.c.query(q)

    def testSelectDotSemicolon(self):
        q = "select .;"
        self.assertRaises(pg.DatabaseError, self.c.query, q)

    def testGetresult(self):
        q = "select 0"
        result = [(0,)]
        r = self.c.query(q).getresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, tuple)
        self.assertIsInstance(v[0], int)
        self.assertEqual(r, result)

    def testGetresultLong(self):
        q = "select 9876543210"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testGetresultDecimal(self):
        q = "select 98765432109876543210"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testGetresultString(self):
        result = 'Hello, world!'
        q = "select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresult(self):
        q = "select 0 as alias0"
        result = [{'alias0': 0}]
        r = self.c.query(q).dictresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, dict)
        self.assertIsInstance(v['alias0'], int)
        self.assertEqual(r, result)

    def testDictresultLong(self):
        q = "select 9876543210 as longjohnsilver"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testDictresultDecimal(self):
        q = "select 98765432109876543210 as longjohnsilver"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testDictresultString(self):
        result = 'Hello, world!'
        q = "select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testNamedresult(self):
        q = "select 0 as alias0"
        result = [(0,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('alias0',))
        self.assertEqual(v.alias0, 0)

    def testGet3Cols(self):
        q = "select 1,2,3"
        result = [(1, 2, 3)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [dict(a=1, b=2, c=3)]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [(1, 2, 3)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('a', 'b', 'c'))
        self.assertEqual(v.b, 2)

    def testGet3Rows(self):
        q = "select 3 union select 1 union select 2 order by 1"
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        for v in r:
            self.assertEqual(v._fields, ('alias3',))

    def testDictresultNames(self):
        # unquoted aliases are folded to lower case, quoted ones are kept
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [{'mixedcasealias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        result = [{'MixedCaseAlias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testNamedresultNames(self):
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [('MixedCase',)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('mixedcasealias',))
        self.assertEqual(v.mixedcasealias, 'MixedCase')
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('MixedCaseAlias',))
        self.assertEqual(v.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = 100
        num_rows = 100
        q = "select " + ','.join(map(str, range(num_cols)))
        q = ' union all '.join((q,) * num_rows)
        r = self.c.query(q).getresult()
        result = [tuple(range(num_cols))] * num_rows
        self.assertEqual(r, result)

    def testListfields(self):
        q = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        r = self.c.query(q).listfields()
        self.assertIsInstance(r, tuple)
        result = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(r, result)

    def testFieldname(self):
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldname(2)
        self.assertEqual(r, 'x')
        r = self.c.query(q).fieldname(3)
        self.assertEqual(r, 'y')

    def testFieldnum(self):
        q = "select 1 as x"
        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        r = self.c.query(q).fieldnum('y')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3)

    def testNtuples(self):
        q = "select 1 where false"
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 6)

    def testQuery(self):
        """Check the return values of query() for DDL and DML commands.

        Creating a table returns None, a single-row insert returns the OID
        of the new row as an int, and multi-row inserts, updates and
        deletes return the number of affected rows as a str.
        """
        query = self.c.query
        query("drop table if exists test_table")
        # 'if exists' keeps the cleanup from adding a second, misleading
        # error when the table could not be created in the first place.
        self.addCleanup(query, "drop table if exists test_table")
        q = "create table test_table (n integer) with oids"
        r = query(q)
        self.assertIsNone(r)
        q = "insert into test_table values (1)"
        r = query(q)
        self.assertIsInstance(r, int)
        q = "insert into test_table select 2"
        r = query(q)
        self.assertIsInstance(r, int)
        oid = r
        q = "select oid from test_table where n=2"
        r = query(q).getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertIsInstance(r, int)
        self.assertEqual(r, oid)
        q = "insert into test_table select 3 union select 4 union select 5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '3')
        q = "update test_table set n=4 where n<5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '4')
        q = "delete from test_table"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '5')
573
574
class TestUnicodeQueries(unittest.TestCase):
    """Test unicode strings as queries via a basic pg connection.

    Most tests send a query first as a unicode string and then encoded to
    bytes in the current client encoding, and expect the same str result
    both times.  On Python 2 (no unicode_strings) that result is the
    encoded byte string.
    """

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')

    def tearDown(self):
        self.c.close()

    def testGetresulAscii(self):
        result = u'Hello, world!'
        q = u"select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresulAscii(self):
        result = u'Hello, world!'
        q = u"select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('utf8')
        # pass the query as unicode
        try:
            v = self.c.query(q).getresult()[0][0]
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        # pass the query as bytes
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('utf8')
        try:
            v = self.c.query(q).dictresult()[0]['greeting']
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin1(self):
        # This test used to be named testDictresultLatin1 as well, so the
        # dictresult variant below shadowed it and it never ran.
        try:
            self.c.query('set client_encoding=latin1')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        # every character of this sample is encodable in ISO 8859-15
        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        # every character of this sample is encodable in ISO 8859-15
        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
        q = u"select '%s' as menu" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
734
735
736class TestParamQueries(unittest.TestCase):
737    """Test queries with parameters via a basic pg connection."""
738
739    def setUp(self):
740        self.c = connect()
741        self.c.query('set client_encoding=utf8')
742
743    def tearDown(self):
744        self.c.close()
745
746    def testQueryWithNoneParam(self):
747        self.assertRaises(TypeError, self.c.query, "select $1", None)
748        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
749        self.assertEqual(self.c.query("select $1::integer", (None,)
750            ).getresult(), [(None,)])
751        self.assertEqual(self.c.query("select $1::text", [None]
752            ).getresult(), [(None,)])
753        self.assertEqual(self.c.query("select $1::text", [[None]]
754            ).getresult(), [(None,)])
755
756    def testQueryWithBoolParams(self, bool_enabled=None):
757        query = self.c.query
758        if bool_enabled is not None:
759            bool_enabled_default = pg.get_bool()
760            pg.set_bool(bool_enabled)
761        try:
762            bool_on = bool_enabled or bool_enabled is None
763            v_false, v_true = (False, True) if bool_on else 'ft'
764            r_false, r_true = [(v_false,)], [(v_true,)]
765            self.assertEqual(query("select false").getresult(), r_false)
766            self.assertEqual(query("select true").getresult(), r_true)
767            q = "select $1::bool"
768            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
769            self.assertEqual(query(q, ('f',)).getresult(), r_false)
770            self.assertEqual(query(q, ('t',)).getresult(), r_true)
771            self.assertEqual(query(q, ('false',)).getresult(), r_false)
772            self.assertEqual(query(q, ('true',)).getresult(), r_true)
773            self.assertEqual(query(q, ('n',)).getresult(), r_false)
774            self.assertEqual(query(q, ('y',)).getresult(), r_true)
775            self.assertEqual(query(q, (0,)).getresult(), r_false)
776            self.assertEqual(query(q, (1,)).getresult(), r_true)
777            self.assertEqual(query(q, (False,)).getresult(), r_false)
778            self.assertEqual(query(q, (True,)).getresult(), r_true)
779        finally:
780            if bool_enabled is not None:
781                pg.set_bool(bool_enabled_default)
782
783    def testQueryWithBoolParamsNotDefault(self):
784        self.testQueryWithBoolParams(bool_enabled=not pg.get_bool())
785
786    def testQueryWithIntParams(self):
787        query = self.c.query
788        self.assertEqual(query("select 1+1").getresult(), [(2,)])
789        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
790        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
791        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
792        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
793        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
794            [(Decimal('2'),)])
795        self.assertEqual(query("select 1, $1::integer", (2,)
796            ).getresult(), [(1, 2)])
797        self.assertEqual(query("select 1 union select $1::integer", (2,)
798            ).getresult(), [(1,), (2,)])
799        self.assertEqual(query("select $1::integer+$2", (1, 2)
800            ).getresult(), [(3,)])
801        self.assertEqual(query("select $1::integer+$2", [1, 2]
802            ).getresult(), [(3,)])
803        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
804            ).getresult(), [(15,)])
805
806    def testQueryWithStrParams(self):
807        query = self.c.query
808        self.assertEqual(query("select $1||', world!'", ('Hello',)
809            ).getresult(), [('Hello, world!',)])
810        self.assertEqual(query("select $1||', world!'", ['Hello']
811            ).getresult(), [('Hello, world!',)])
812        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
813            ).getresult(), [('Hello, world!',)])
814        self.assertEqual(query("select $1::text", ('Hello, world!',)
815            ).getresult(), [('Hello, world!',)])
816        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
817            ).getresult(), [('Hello', 'world')])
818        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
819            ).getresult(), [('Hello', 'world')])
820        self.assertEqual(query("select $1::text union select $2::text",
821            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
822        try:
823            query("select 'wörld'")
824        except (pg.DataError, pg.NotSupportedError):
825            self.skipTest('database does not support utf8')
826        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
827            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
828
829    def testQueryWithUnicodeParams(self):
830        query = self.c.query
831        try:
832            query('set client_encoding=utf8')
833            query("select 'wörld'").getresult()[0][0] == 'wörld'
834        except (pg.DataError, pg.NotSupportedError):
835            self.skipTest("database does not support utf8")
836        self.assertEqual(query("select $1||', '||$2||'!'",
837            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
838
839    def testQueryWithUnicodeParamsLatin1(self):
840        query = self.c.query
841        try:
842            query('set client_encoding=latin1')
843            query("select 'wörld'").getresult()[0][0] == 'wörld'
844        except (pg.DataError, pg.NotSupportedError):
845            self.skipTest("database does not support latin1")
846        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
847        if unicode_strings:
848            self.assertEqual(r, [('Hello, wörld!',)])
849        else:
850            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
851        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
852            ('Hello', u'ЌОр'))
853        query('set client_encoding=iso_8859_1')
854        r = query("select $1||', '||$2||'!'",
855            ('Hello', u'wörld')).getresult()
856        if unicode_strings:
857            self.assertEqual(r, [('Hello, wörld!',)])
858        else:
859            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
860        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
861            ('Hello', u'ЌОр'))
862        query('set client_encoding=sql_ascii')
863        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
864            ('Hello', u'wörld'))
865
866    def testQueryWithUnicodeParamsCyrillic(self):
867        query = self.c.query
868        try:
869            query('set client_encoding=iso_8859_5')
870            query("select 'ЌОр'").getresult()[0][0] == 'ЌОр'
871        except (pg.DataError, pg.NotSupportedError):
872            self.skipTest("database does not support cyrillic")
873        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
874            ('Hello', u'wörld'))
875        r = query("select $1||', '||$2||'!'",
876            ('Hello', u'ЌОр')).getresult()
877        if unicode_strings:
878            self.assertEqual(r, [('Hello, ЌОр!',)])
879        else:
880            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
881        query('set client_encoding=sql_ascii')
882        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
883            ('Hello', u'ЌОр!'))
884
885    def testQueryWithMixedParams(self):
886        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
887            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
888        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
889            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
890
891    def testQueryWithDuplicateParams(self):
892        self.assertRaises(pg.ProgrammingError,
893            self.c.query, "select $1+$1", (1,))
894        self.assertRaises(pg.ProgrammingError,
895            self.c.query, "select $1+$1", (1, 2))
896
897    def testQueryWithZeroParams(self):
898        self.assertEqual(self.c.query("select 1+1", []
899            ).getresult(), [(2,)])
900
901    def testQueryWithGarbage(self):
902        garbage = r"'\{}+()-#[]oo324"
903        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
904            ).dictresult(), [{'garbage': garbage}])
905
906
class TestQueryResultTypes(unittest.TestCase):
    """Test proper result types via a basic pg connection."""

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set timezone='UTC'")

    def tearDown(self):
        self.c.close()

    def assert_proper_cast(self, value, pgtype, pytype):
        """Check that value cast to pgtype is returned as a pytype object.

        The check is done for a scalar and for a one-element array.
        Arrays of date, time and interval types are expected to come back
        as their unparsed array literal string.
        Skips the test if a json type is not supported by the database.
        """
        q = 'select $1::%s' % (pgtype,)
        try:
            r = self.c.query(q, (value,)).getresult()[0][0]
        except pg.ProgrammingError:
            if pgtype in ('json', 'jsonb'):
                self.skipTest('database does not support json')
            # any other failure is a genuine error; re-raise it instead
            # of continuing with r unbound
            raise
        self.assertIsInstance(r, pytype)
        if isinstance(value, str):
            # quote strings that would be misparsed in an array literal
            if not value or ' ' in value or '{' in value:
                value = '"%s"' % value
        value = '{%s}' % value
        r = self.c.query(q + '[]', (value,)).getresult()[0][0]
        if pgtype.startswith(('date', 'time', 'interval')):
            # arrays of these are casted by the DB wrapper only
            self.assertEqual(r, value)
        else:
            self.assertIsInstance(r, list)
            self.assertEqual(len(r), 1)
            self.assertIsInstance(r[0], pytype)

    def testInt(self):
        self.assert_proper_cast(0, 'int', int)
        self.assert_proper_cast(0, 'smallint', int)
        self.assert_proper_cast(0, 'oid', int)
        self.assert_proper_cast(0, 'cid', int)
        self.assert_proper_cast(0, 'xid', int)

    def testLong(self):
        self.assert_proper_cast(0, 'bigint', long)

    def testFloat(self):
        self.assert_proper_cast(0, 'float', float)
        self.assert_proper_cast(0, 'real', float)
        # plain 'double' is not a PostgreSQL type name, use its alias float8
        self.assert_proper_cast(0, 'float8', float)
        self.assert_proper_cast(0, 'double precision', float)
        self.assert_proper_cast('infinity', 'float', float)

    def testNumeric(self):
        # formerly also named testFloat, which silently shadowed the
        # float test above so that it never ran
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal(0), 'numeric', decimal)
        self.assert_proper_cast(decimal(0), 'decimal', decimal)

    def testMoney(self):
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal('0'), 'money', decimal)

    def testBool(self):
        bool_type = bool if pg.get_bool() else str
        self.assert_proper_cast('f', 'bool', bool_type)

    def testDate(self):
        self.assert_proper_cast('1956-01-31', 'date', str)
        self.assert_proper_cast('10:20:30', 'interval', str)
        self.assert_proper_cast('08:42:15', 'time', str)
        self.assert_proper_cast('08:42:15+00', 'timetz', str)
        self.assert_proper_cast('1956-01-31 08:42:15', 'timestamp', str)
        self.assert_proper_cast('1956-01-31 08:42:15+00', 'timestamptz', str)

    def testText(self):
        self.assert_proper_cast('', 'text', str)
        self.assert_proper_cast('', 'char', str)
        self.assert_proper_cast('', 'bpchar', str)
        self.assert_proper_cast('', 'varchar', str)

    def testBytea(self):
        self.assert_proper_cast('', 'bytea', bytes)

    def testJson(self):
        self.assert_proper_cast('{}', 'json', dict)
989
990
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Create the test table and check how the database counts length."""
        c = connect()
        c.query("drop table if exists test cascade")
        c.query("create table test ("
            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
            "d numeric, f4 real, f8 double precision, m money,"
            "c char(1), v4 varchar(4), c4 char(4), t text)")
        # Check whether the test database uses SQL_ASCII - this means
        # that it does not consider encoding when calculating lengths.
        c.query("set client_encoding=utf8")
        try:
            c.query("select 'À'")
        except (pg.DataError, pg.NotSupportedError):
            cls.has_encoding = False
        else:
            cls.has_encoding = c.query(
                "select length('À') - length('a')").getresult()[0][0] == 0
        c.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        c = connect()
        c.query("drop table test cascade")
        c.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def db_len(cls, s, encoding):
        """Return the length of s as counted by the test database.

        Characters are counted if the database considers the encoding,
        otherwise the bytes of the encoded string are counted.
        """
        if cls.has_encoding:
            s = s if isinstance(s, unicode) else s.decode(encoding)
        else:
            s = s.encode(encoding) if isinstance(s, unicode) else s
        return len(s)

    @classmethod
    def _unicode_row(cls, text):
        """Return a row of the test table with unicode strings and text."""
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if cls.has_encoding else u'$'
        return (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0', c, u'bÀd', u'bÀd', text)

    @staticmethod
    def _encoded_row(row, encoding):
        """Return row with all unicode values encoded to byte strings."""
        return tuple(s.encode(encoding) if isinstance(s, unicode) else s
            for s in row)

    def get_back(self, encoding='utf-8'):
        """Convert boolean and decimal values back."""
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], bool)
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision (already a float)
                self.assertIsInstance(row[8], float)
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(self.db_len(row[10], encoding), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(self.db_len(row[11], encoding), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(self.db_len(row[12], encoding), 4)
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            row = tuple(row)
            data.append(row)
        return data

    def testInserttable1Row(self):
        data = self.data[2:3]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        data = self.data
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMultipleRows(self):
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        data = [(None,) * 14] * 100
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableByteValues(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        row_bytes = self._encoded_row(row_unicode, 'utf-8')
        data = [row_bytes] * 2
        self.c.inserttable('test', data)
        if unicode_strings:
            data = [row_unicode] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeUtf8(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encoded_row(row_unicode, 'utf-8')] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeLatin1(self):
        try:
            self.c.query("set client_encoding=latin1")
            self.c.query("select '¥'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode € sign with latin1 encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
        # replace it with the yen sign, a single latin1 character
        # which therefore still fits into the char(1) column
        row_unicode = tuple(s.replace(u'€', u'¥')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encoded_row(row_unicode, 'latin1')] * 2
        self.assertEqual(self.get_back('latin1'), data)

    def testInserttableUnicodeLatin9(self):
        try:
            self.c.query("set client_encoding=latin9")
            self.c.query("select '€'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encoded_row(row_unicode, 'latin9')] * 2
        self.assertEqual(self.get_back('latin9'), data)

    def testInserttableNoEncoding(self):
        self.c.query("set client_encoding=sql_ascii")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode non-ascii unicode without a specific encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
1228
1229
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access."""

    # set by setUpClass; checked in setUp so tests fail if setup did not run
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """(Re)create the two-column table used by the copy tests."""
        c = connect()
        c.query("drop table if exists test cascade")
        c.query("create table test (i int, v varchar(16))")
        c.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the test table again."""
        c = connect()
        c.query("drop table test cascade")
        c.close()

    def setUp(self):
        """Open a fresh connection that uses the utf8 client encoding."""
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")

    def tearDown(self):
        """Empty the test table and close the connection."""
        self.c.query("truncate table test")
        self.c.close()

    def testPutline(self):
        """Feed rows to 'copy test from stdin' line by line via putline()."""
        putline = self.c.putline
        query = self.c.query
        data = list(enumerate("apple pear plum cherry banana".split()))
        query("copy test from stdin")
        try:
            # one tab-separated line per row, then the end marker line
            for i, v in data:
                putline("%d\t%s\n" % (i, v))
            putline("\\.\n")
        finally:
            # end the copy even if sending a line failed
            self.c.endcopy()
        r = query("select * from test").getresult()
        self.assertEqual(r, data)

    def testPutlineBytesAndUnicode(self):
        """Check that putline() accepts lines given as bytes and as str."""
        putline = self.c.putline
        query = self.c.query
        try:
            query("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        query("copy test from stdin")
        try:
            putline(u"47\tkÀse\n".encode('utf8'))
            putline("35\twÃŒrstel\n")
            putline(b"\\.\n")
        finally:
            # end the copy even if sending a line failed
            self.c.endcopy()
        r = query("select * from test").getresult()
        self.assertEqual(r, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        """Read rows from 'copy test to stdout' line by line via getline()."""
        getline = self.c.getline
        query = self.c.query
        data = list(enumerate("apple banana pear plum strawberry".split()))
        n = len(data)
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            # expected: n data lines, then the end marker '\.', then None
            for i in range(n + 2):
                v = getline()
                if i < n:
                    self.assertEqual(v, '%d\t%s' % data[i])
                elif i == n:
                    self.assertEqual(v, '\\.')
                else:
                    self.assertIsNone(v)
        finally:
            # end the copy; an IOError from endcopy() is deliberately ignored
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testGetlineBytesAndUnicode(self):
        """Check that getline() returns str for rows inserted as bytes/str."""
        getline = self.c.getline
        query = self.c.query
        try:
            query("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        data = [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')]
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            v = getline()
            self.assertIsInstance(v, str)
            self.assertEqual(v, '54\tkÀse')
            v = getline()
            self.assertIsInstance(v, str)
            self.assertEqual(v, '73\twÃŒrstel')
            self.assertEqual(getline(), '\\.')
            self.assertIsNone(getline())
        finally:
            # end the copy; an IOError from endcopy() is deliberately ignored
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        """Check that wrong numbers of arguments raise a TypeError."""
        self.assertRaises(TypeError, self.c.putline)
        self.assertRaises(TypeError, self.c.getline, 'invalid')
        self.assertRaises(TypeError, self.c.endcopy, 'invalid')
1340
1341
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # run pending cleanups (like dropping the test function) first,
        # because they need the connection which is closed afterwards
        self.doCleanups()
        self.c.close()

    def assert_notify(self, notify, payload):
        """Assert that notify is a test_notify notification with payload.

        The notification is expected to be a 3-tuple of (str, int, str)
        with the channel name first and the payload last.
        """
        self.assertIsInstance(notify, tuple)
        self.assertEqual(len(notify), 3)
        self.assertIsInstance(notify[0], str)
        self.assertIsInstance(notify[1], int)
        self.assertIsInstance(notify[2], str)
        self.assertEqual(notify[0], 'test_notify')
        self.assertEqual(notify[2], payload)

    def testGetNotify(self):
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            self.assert_notify(getnotify(), '')
            self.assertIsNone(getnotify())
            query("notify test_notify, 'test_payload'")
            self.assert_notify(getnotify(), 'test_payload')
            self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 'invalid')
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))
        self.assertIsNone(self.c.set_notice_receiver(None))

    def testSetAndGetNoticeReceiver(self):
        r = lambda notice: None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)
        self.assertIsNone(self.c.set_notice_receiver(None))
        self.assertIsNone(self.c.get_notice_receiver())

    def testNoticeReceiver(self):
        # the cleanup is registered before the function is created, so it
        # must not fail (and hide the real error) if creation failed
        self.addCleanup(self.c.query,
            'drop function if exists bilbo_notice();')
        self.c.query('''create function bilbo_notice() returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        received = {}

        def notice_receiver(notice):
            """Collect all public attributes of the notice."""
            for attr in dir(notice):
                if attr.startswith('__'):
                    continue
                value = getattr(notice, attr)
                if isinstance(value, str):
                    # normalize a German server locale
                    value = value.replace('WARNUNG', 'WARNING')
                received[attr] = value

        self.c.set_notice_receiver(notice_receiver)
        self.c.query('select bilbo_notice()')
        self.assertEqual(received, dict(
            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
            severity='WARNING', primary='Bilbo was here!',
            detail=None, hint=None))
1422
1423
1424class TestConfigFunctions(unittest.TestCase):
1425    """Test the functions for changing default settings.
1426
1427    To test the effect of most of these functions, we need a database
1428    connection.  That's why they are covered in this test module.
1429    """
1430
1431    def setUp(self):
1432        self.c = connect()
1433        self.c.query("set client_encoding=utf8")
1434        self.c.query('set bytea_output=hex')
1435        self.c.query("set lc_monetary='C'")
1436
1437    def tearDown(self):
1438        self.c.close()
1439
1440    def testGetDecimalPoint(self):
1441        point = pg.get_decimal_point()
1442        # error if a parameter is passed
1443        self.assertRaises(TypeError, pg.get_decimal_point, point)
1444        self.assertIsInstance(point, str)
1445        self.assertEqual(point, '.')  # the default setting
1446        pg.set_decimal_point(',')
1447        try:
1448            r = pg.get_decimal_point()
1449        finally:
1450            pg.set_decimal_point(point)
1451        self.assertIsInstance(r, str)
1452        self.assertEqual(r, ',')
1453        pg.set_decimal_point("'")
1454        try:
1455            r = pg.get_decimal_point()
1456        finally:
1457            pg.set_decimal_point(point)
1458        self.assertIsInstance(r, str)
1459        self.assertEqual(r, "'")
1460        pg.set_decimal_point('')
1461        try:
1462            r = pg.get_decimal_point()
1463        finally:
1464            pg.set_decimal_point(point)
1465        self.assertIsNone(r)
1466        pg.set_decimal_point(None)
1467        try:
1468            r = pg.get_decimal_point()
1469        finally:
1470            pg.set_decimal_point(point)
1471        self.assertIsNone(r)
1472
1473    def testSetDecimalPoint(self):
1474        d = pg.Decimal
1475        point = pg.get_decimal_point()
1476        self.assertRaises(TypeError, pg.set_decimal_point)
1477        # error if decimal point is not a string
1478        self.assertRaises(TypeError, pg.set_decimal_point, 0)
1479        # error if more than one decimal point passed
1480        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
1481        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
1482        # error if decimal point is not a punctuation character
1483        self.assertRaises(TypeError, pg.set_decimal_point, '0')
1484        query = self.c.query
1485        # check that money values are interpreted as decimal values
1486        # only if decimal_point is set, and that the result is correct
1487        # only if it is set suitable for the current lc_monetary setting
1488        select_money = "select '34.25'::money"
1489        proper_money = d('34.25')
1490        bad_money = d('3425')
1491        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
1492        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
1493        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
1494        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
1495            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
1496        # first try with English localization (using the point)
1497        for lc in en_locales:
1498            try:
1499                query("set lc_monetary='%s'" % lc)
1500            except pg.DataError:
1501                pass
1502            else:
1503                break
1504        else:
1505            self.skipTest("cannot set English money locale")
1506        try:
1507            r = query(select_money)
1508        except pg.DataError:
1509            # this can happen if the currency signs cannot be
1510            # converted using the encoding of the test database
1511            self.skipTest("database does not support English money")
1512        pg.set_decimal_point(None)
1513        try:
1514            r = r.getresult()[0][0]
1515        finally:
1516            pg.set_decimal_point(point)
1517        self.assertIsInstance(r, str)
1518        self.assertIn(r, en_money)
1519        r = query(select_money)
1520        pg.set_decimal_point('')
1521        try:
1522            r = r.getresult()[0][0]
1523        finally:
1524            pg.set_decimal_point(point)
1525        self.assertIsInstance(r, str)
1526        self.assertIn(r, en_money)
1527        r = query(select_money)
1528        pg.set_decimal_point('.')
1529        try:
1530            r = r.getresult()[0][0]
1531        finally:
1532            pg.set_decimal_point(point)
1533        self.assertIsInstance(r, d)
1534        self.assertEqual(r, proper_money)
1535        r = query(select_money)
1536        pg.set_decimal_point(',')
1537        try:
1538            r = r.getresult()[0][0]
1539        finally:
1540            pg.set_decimal_point(point)
1541        self.assertIsInstance(r, d)
1542        self.assertEqual(r, bad_money)
1543        r = query(select_money)
1544        pg.set_decimal_point("'")
1545        try:
1546            r = r.getresult()[0][0]
1547        finally:
1548            pg.set_decimal_point(point)
1549        self.assertIsInstance(r, d)
1550        self.assertEqual(r, bad_money)
1551        # then try with German localization (using the comma)
1552        for lc in de_locales:
1553            try:
1554                query("set lc_monetary='%s'" % lc)
1555            except pg.DataError:
1556                pass
1557            else:
1558                break
1559        else:
1560            self.skipTest("cannot set German money locale")
1561        select_money = select_money.replace('.', ',')
1562        try:
1563            r = query(select_money)
1564        except pg.DataError:
1565            self.skipTest("database does not support English money")
1566        pg.set_decimal_point(None)
1567        try:
1568            r = r.getresult()[0][0]
1569        finally:
1570            pg.set_decimal_point(point)
1571        self.assertIsInstance(r, str)
1572        self.assertIn(r, de_money)
1573        r = query(select_money)
1574        pg.set_decimal_point('')
1575        try:
1576            r = r.getresult()[0][0]
1577        finally:
1578            pg.set_decimal_point(point)
1579        self.assertIsInstance(r, str)
1580        self.assertIn(r, de_money)
1581        r = query(select_money)
1582        pg.set_decimal_point(',')
1583        try:
1584            r = r.getresult()[0][0]
1585        finally:
1586            pg.set_decimal_point(point)
1587        self.assertIsInstance(r, d)
1588        self.assertEqual(r, proper_money)
1589        r = query(select_money)
1590        pg.set_decimal_point('.')
1591        try:
1592            r = r.getresult()[0][0]
1593        finally:
1594            pg.set_decimal_point(point)
1595        self.assertEqual(r, bad_money)
1596        r = query(select_money)
1597        pg.set_decimal_point("'")
1598        try:
1599            r = r.getresult()[0][0]
1600        finally:
1601            pg.set_decimal_point(point)
1602        self.assertEqual(r, bad_money)
1603
1604    def testGetDecimal(self):
1605        decimal_class = pg.get_decimal()
1606        # error if a parameter is passed
1607        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
1608        self.assertIs(decimal_class, pg.Decimal)  # the default setting
1609        pg.set_decimal(int)
1610        try:
1611            r = pg.get_decimal()
1612        finally:
1613            pg.set_decimal(decimal_class)
1614        self.assertIs(r, int)
1615        r = pg.get_decimal()
1616        self.assertIs(r, decimal_class)
1617
1618    def testSetDecimal(self):
1619        decimal_class = pg.get_decimal()
1620        # error if no parameter is passed
1621        self.assertRaises(TypeError, pg.set_decimal)
1622        query = self.c.query
1623        try:
1624            r = query("select 3425::numeric")
1625        except pg.DatabaseError:
1626            self.skipTest('database does not support numeric')
1627        r = r.getresult()[0][0]
1628        self.assertIsInstance(r, decimal_class)
1629        self.assertEqual(r, decimal_class('3425'))
1630        r = query("select 3425::numeric")
1631        pg.set_decimal(int)
1632        try:
1633            r = r.getresult()[0][0]
1634        finally:
1635            pg.set_decimal(decimal_class)
1636        self.assertNotIsInstance(r, decimal_class)
1637        self.assertIsInstance(r, int)
1638        self.assertEqual(r, int(3425))
1639
1640    def testGetBool(self):
1641        use_bool = pg.get_bool()
1642        # error if a parameter is passed
1643        self.assertRaises(TypeError, pg.get_bool, use_bool)
1644        self.assertIsInstance(use_bool, bool)
1645        self.assertIs(use_bool, True)  # the default setting
1646        pg.set_bool(False)
1647        try:
1648            r = pg.get_bool()
1649        finally:
1650            pg.set_bool(use_bool)
1651        self.assertIsInstance(r, bool)
1652        self.assertIs(r, False)
1653        pg.set_bool(True)
1654        try:
1655            r = pg.get_bool()
1656        finally:
1657            pg.set_bool(use_bool)
1658        self.assertIsInstance(r, bool)
1659        self.assertIs(r, True)
1660        pg.set_bool(0)
1661        try:
1662            r = pg.get_bool()
1663        finally:
1664            pg.set_bool(use_bool)
1665        self.assertIsInstance(r, bool)
1666        self.assertIs(r, False)
1667        pg.set_bool(1)
1668        try:
1669            r = pg.get_bool()
1670        finally:
1671            pg.set_bool(use_bool)
1672        self.assertIsInstance(r, bool)
1673        self.assertIs(r, True)
1674
1675    def testSetBool(self):
1676        use_bool = pg.get_bool()
1677        # error if no parameter is passed
1678        self.assertRaises(TypeError, pg.set_bool)
1679        query = self.c.query
1680        try:
1681            r = query("select true::bool")
1682        except pg.ProgrammingError:
1683            self.skipTest('database does not support bool')
1684        r = r.getresult()[0][0]
1685        self.assertIsInstance(r, bool)
1686        self.assertEqual(r, True)
1687        r = query("select true::bool")
1688        pg.set_bool(False)
1689        try:
1690            r = r.getresult()[0][0]
1691        finally:
1692            pg.set_bool(use_bool)
1693        self.assertIsInstance(r, str)
1694        self.assertIs(r, 't')
1695        r = query("select true::bool")
1696        pg.set_bool(True)
1697        try:
1698            r = r.getresult()[0][0]
1699        finally:
1700            pg.set_bool(use_bool)
1701        self.assertIsInstance(r, bool)
1702        self.assertIs(r, True)
1703
1704    def testGetByteEscaped(self):
1705        bytea_escaped = pg.get_bytea_escaped()
1706        # error if a parameter is passed
1707        self.assertRaises(TypeError, pg.get_bytea_escaped, bytea_escaped)
1708        self.assertIsInstance(bytea_escaped, bool)
1709        self.assertIs(bytea_escaped, False)  # the default setting
1710        pg.set_bytea_escaped(True)
1711        try:
1712            r = pg.get_bytea_escaped()
1713        finally:
1714            pg.set_bytea_escaped(bytea_escaped)
1715        self.assertIsInstance(r, bool)
1716        self.assertIs(r, True)
1717        pg.set_bytea_escaped(False)
1718        try:
1719            r = pg.get_bytea_escaped()
1720        finally:
1721            pg.set_bytea_escaped(bytea_escaped)
1722        self.assertIsInstance(r, bool)
1723        self.assertIs(r, False)
1724        pg.set_bytea_escaped(1)
1725        try:
1726            r = pg.get_bytea_escaped()
1727        finally:
1728            pg.set_bytea_escaped(bytea_escaped)
1729        self.assertIsInstance(r, bool)
1730        self.assertIs(r, True)
1731        pg.set_bytea_escaped(0)
1732        try:
1733            r = pg.get_bytea_escaped()
1734        finally:
1735            pg.set_bytea_escaped(bytea_escaped)
1736        self.assertIsInstance(r, bool)
1737        self.assertIs(r, False)
1738
1739    def testSetByteaEscaped(self):
1740        bytea_escaped = pg.get_bytea_escaped()
1741        # error if no parameter is passed
1742        self.assertRaises(TypeError, pg.set_bytea_escaped)
1743        query = self.c.query
1744        try:
1745            r = query("select 'data'::bytea")
1746        except pg.ProgrammingError:
1747            self.skipTest('database does not support bytea')
1748        r = r.getresult()[0][0]
1749        self.assertIsInstance(r, bytes)
1750        self.assertEqual(r, b'data')
1751        r = query("select 'data'::bytea")
1752        pg.set_bytea_escaped(True)
1753        try:
1754            r = r.getresult()[0][0]
1755        finally:
1756            pg.set_bytea_escaped(bytea_escaped)
1757        self.assertIsInstance(r, str)
1758        self.assertEqual(r, '\\x64617461')
1759        r = query("select 'data'::bytea")
1760        pg.set_bytea_escaped(False)
1761        try:
1762            r = r.getresult()[0][0]
1763        finally:
1764            pg.set_bytea_escaped(bytea_escaped)
1765        self.assertIsInstance(r, bytes)
1766        self.assertEqual(r, b'data')
1767
1768    def testGetNamedresult(self):
1769        namedresult = pg.get_namedresult()
1770        # error if a parameter is passed
1771        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
1772        self.assertIs(namedresult, pg._namedresult)  # the default setting
1773
1774    def testSetNamedresult(self):
1775        namedresult = pg.get_namedresult()
1776        self.assertTrue(callable(namedresult))
1777
1778        query = self.c.query
1779
1780        r = query("select 1 as x, 2 as y").namedresult()[0]
1781        self.assertIsInstance(r, tuple)
1782        self.assertEqual(r, (1, 2))
1783        self.assertIsNot(type(r), tuple)
1784        self.assertEqual(r._fields, ('x', 'y'))
1785        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
1786        self.assertEqual(r.__class__.__name__, 'Row')
1787
1788        def listresult(q):
1789            return [list(row) for row in q.getresult()]
1790
1791        pg.set_namedresult(listresult)
1792        try:
1793            r = pg.get_namedresult()
1794            self.assertIs(r, listresult)
1795            r = query("select 1 as x, 2 as y").namedresult()[0]
1796            self.assertIsInstance(r, list)
1797            self.assertEqual(r, [1, 2])
1798            self.assertIsNot(type(r), tuple)
1799            self.assertFalse(hasattr(r, '_fields'))
1800            self.assertNotEqual(r.__class__.__name__, 'Row')
1801        finally:
1802            pg.set_namedresult(namedresult)
1803
1804        r = pg.get_namedresult()
1805        self.assertIs(r, namedresult)
1806
1807    def testSetRowFactorySize(self):
1808        try:
1809            from functools import lru_cache
1810        except ImportError:  # Python < 3.2
1811            lru_cache = None
1812        queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc']
1813        query = self.c.query
1814        for maxsize in (None, 0, 1, 2, 3, 10, 1024):
1815            pg.set_row_factory_size(maxsize)
1816            for i in range(3):
1817                for q in queries:
1818                    r = query(q).namedresult()[0]
1819                    if q.endswith('abc'):
1820                        self.assertEqual(r, (123,))
1821                        self.assertEqual(r._fields, ('abc',))
1822                    else:
1823                        self.assertEqual(r, (1, 2, 3))
1824                        self.assertEqual(r._fields, ('a', 'b', 'c'))
1825            if lru_cache:
1826                info = pg._row_factory.cache_info()
1827                self.assertEqual(info.maxsize, maxsize)
1828                self.assertEqual(info.hits + info.misses, 6)
1829                self.assertEqual(info.hits,
1830                    0 if maxsize is not None and maxsize < 2 else 4)
1831
1832
class TestStandaloneEscapeFunctions(unittest.TestCase):
    """Test pg escape functions.

    The libpq interface memorizes some parameters of the last opened
    connection that influence the result of these functions.  Therefore
    we need to open a connection with fixed parameters prior to testing
    in order to ensure that the tests always run under the same conditions.
    That's why these tests are included in this test module.
    """

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Open a connection with fixed settings, then close it again."""
        db = connect()
        try:
            query = db.query
            query('set client_encoding=sql_ascii')
            query('set standard_conforming_strings=off')
            try:
                query('set bytea_output=escape')
            except pg.ProgrammingError:
                # only tolerate the failure for server versions before 9.0
                if db.server_version >= 90000:
                    raise
        finally:
            # close the connection even if one of the settings failed
            db.close()
        cls.cls_set_up = True

    def testEscapeString(self):
        """Check escaping of quotes and backslashes in strings."""
        self.assertTrue(self.cls_set_up)
        f = pg.escape_string
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        r = f(u"das is' käse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, u"das is'' käse".encode('utf-8'))
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(r"It's bad to have a \ inside.")
        self.assertEqual(r, r"It''s bad to have a \\ inside.")

    def testEscapeBytea(self):
        """Check escaping of quotes and non-ASCII bytes in bytea data."""
        self.assertTrue(self.cls_set_up)
        f = pg.escape_bytea
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # the UTF-8 bytes of the umlaut (octal 303 244) become octal escapes
        r = f(u"das is' käse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(b'O\x00ps\xff!')
        self.assertEqual(r, b'O\\\\000ps\\\\377!')
1894
1895
if __name__ == '__main__':
    # run the test cases of this module when it is executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.