source: trunk/tests/test_classic_connection.py @ 901

Last change on this file since 901 was 901, checked in by cito, 3 years ago

Improve creation of named tuples in Python 2.6 and 3.0

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 69.2 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17import threading
18import time
19import os
20
21from collections import namedtuple
22from decimal import Decimal
23
24import pg  # the module under test
25
# We need a database to test against.  If LOCAL_PyGreSQL.py exists, we will
# get our connection settings from there; otherwise we use the defaults
# below.  These tests should be run with various PostgreSQL versions and
# databases created with different encodings and locales.  In particular,
# make sure the tests are run against databases created with both SQL_ASCII
# and UTF8.
dbname = 'unittest'
dbhost = None
dbport = 5432

# Settings found in LOCAL_PyGreSQL.py override the defaults above, so this
# import must stay after them.  The relative import is tried first; catching
# ValueError presumably covers Python 2, where a relative import outside of
# a package raises ValueError rather than ImportError.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Compatibility aliases for names that are builtins only in Python 2.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# True when the native str type holds unicode text (Python 3),
# False when str is the same type as bytes (Python 2).
unicode_strings = str is not bytes

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost(), so the tests avoid
# asking for the host there:
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
61
62
def connect():
    """Open a pg connection to the configured test database.

    The connection uses the module-level dbname, dbhost and dbport settings
    and is told to send only messages of level warning and above.
    """
    conn = pg.connect(dbname, dbhost, dbport)
    conn.query("set client_min_messages=warning")
    return conn
68
69
class TestCanConnect(unittest.TestCase):
    """Check that a plain connection to PostgreSQL can be opened and closed."""

    def testCanConnect(self):
        try:
            conn = connect()
        except pg.Error as exc:
            self.fail('Cannot connect to database %s:\n%s' % (dbname, exc))
        else:
            # only reached when the connection could be established
            try:
                conn.close()
            except pg.Error:
                self.fail('Cannot close the database connection')
82
83
class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods."""

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        try:
            self.connection.close()
        except pg.InternalError:
            # closing a connection that is already closed raises this error
            pass

    def is_method(self, attribute):
        """Check if given attribute on the connection is a method."""
        if do_not_ask_for_host and attribute == 'host':
            # do not touch the host attribute where libpq may crash on it
            return False
        return callable(getattr(self.connection, attribute))

    def testClassName(self):
        self.assertEqual(self.connection.__class__.__name__, 'Connection')

    def testModuleName(self):
        self.assertEqual(self.connection.__class__.__module__, 'pg')

    def testStr(self):
        r = str(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testRepr(self):
        r = repr(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testAllConnectAttributes(self):
        attributes = '''db error host options port
            protocol_version server_version status user'''.split()
        connection_attributes = [a for a in dir(self.connection)
            if not a.startswith('__') and not self.is_method(a)]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        methods = '''cancel close date_format endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_cast_hook get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter putline query reset
            set_cast_hook set_notice_receiver source transaction'''.split()
        connection_methods = [a for a in dir(self.connection)
            if not a.startswith('__') and self.is_method(a)]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        # no error is expected, apart from messages mentioning krb5_
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, dbhost or def_host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        # PostgreSQL 10 and newer report major * 10000 + minor (10.1 is
        # 100001), so the upper bound must admit six-digit numbers.
        self.assertTrue(70400 <= server_version < 1000000)

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testMethodEndcopy(self):
        # outside of a copy operation, endcopy may raise IOError
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # reopen the connection so that tearDown has something to close
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query
        # check that client encoding gets reset
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            # the cancelled query is expected to raise a DatabaseError
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.DatabaseError as error:
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        while 1:  # make sure the query is really running
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)

    def testMethodTransaction(self):
        transaction = self.connection.transaction
        self.assertRaises(TypeError, transaction, None)
        self.assertEqual(transaction(), pg.TRANS_IDLE)
        self.connection.query('begin')
        self.assertEqual(transaction(), pg.TRANS_INTRANS)
        self.connection.query('rollback')
        self.assertEqual(transaction(), pg.TRANS_IDLE)

    def testMethodParameter(self):
        parameter = self.connection.parameter
        query = self.connection.query
        self.assertRaises(TypeError, parameter)
        r = parameter('this server setting does not exist')
        self.assertIsNone(r)
        # The version string may contain lower case letters (e.g. a beta
        # tag like "10beta1"), so it must be compared verbatim.
        s = query('show server_version').getresult()[0][0]
        self.assertIsNotNone(s)
        r = parameter('server_version')
        self.assertEqual(r, s)
        # encoding names are compared in upper case
        for name in ('server_encoding', 'client_encoding'):
            s = query('show %s' % name).getresult()[0][0]
            self.assertIsNotNone(s)
            r = parameter(name)
            self.assertEqual(r, s.upper())
290
291
class TestSimpleQueries(unittest.TestCase):
    """Test simple queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # Run the registered cleanups explicitly while the connection is
        # still open, since they use it (see the table cleanup in testQuery).
        self.doCleanups()
        self.c.close()

    def testClassName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__name__, 'Query')

    def testModuleName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__module__, 'pg')

    def testStr(self):
        q = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        r = self.c.query(q)
        self.assertEqual(str(r),
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)')

    def testRepr(self):
        r = repr(self.c.query("select 1"))
        self.assertTrue(r.startswith('<pg.Query object'), r)

    def testSelect0(self):
        q = "select 0"
        self.c.query(q)

    def testSelect0Semicolon(self):
        q = "select 0;"
        self.c.query(q)

    def testSelectDotSemicolon(self):
        q = "select .;"
        self.assertRaises(pg.DatabaseError, self.c.query, q)

    def testGetresult(self):
        q = "select 0"
        result = [(0,)]
        r = self.c.query(q).getresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, tuple)
        self.assertIsInstance(v[0], int)
        self.assertEqual(r, result)

    def testGetresultLong(self):
        q = "select 9876543210"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testGetresultDecimal(self):
        q = "select 98765432109876543210"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testGetresultString(self):
        result = 'Hello, world!'
        q = "select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresult(self):
        q = "select 0 as alias0"
        result = [{'alias0': 0}]
        r = self.c.query(q).dictresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, dict)
        self.assertIsInstance(v['alias0'], int)
        self.assertEqual(r, result)

    def testDictresultLong(self):
        q = "select 9876543210 as longjohnsilver"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testDictresultDecimal(self):
        q = "select 98765432109876543210 as longjohnsilver"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testDictresultString(self):
        result = 'Hello, world!'
        q = "select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testNamedresult(self):
        q = "select 0 as alias0"
        result = [(0,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('alias0',))
        self.assertEqual(v.alias0, 0)

    def testNamedresultWithGoodFieldnames(self):
        q = 'select 1 as snake_case_alias, 2 as "CamelCaseAlias"'
        result = [(1, 2)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('snake_case_alias', 'CamelCaseAlias'))

    def testNamedresultWithBadFieldnames(self):
        # Find out which names invalid field names get replaced with:
        # namedtuple's rename option provides them where it exists,
        # otherwise (Python 2.6 or 3.0) column_<n> names are expected.
        try:
            r = namedtuple('Bad', ['?'] * 6, rename=True)
        except TypeError:  # Python 2.6 or 3.0
            fields = tuple('column_%d' % n for n in range(6))
        else:
            fields = r._fields
        q = ('select 3 as "0alias", 4 as _alias, 5 as "alias$", 6 as "alias?",'
            ' 7 as "kebap-case-alias", 8 as break, 9 as and_a_good_one')
        result = [tuple(range(3, 10))]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields[:6], fields)
        self.assertEqual(v._fields[6], 'and_a_good_one')

    def testGet3Cols(self):
        q = "select 1,2,3"
        result = [(1, 2, 3)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [dict(a=1, b=2, c=3)]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [(1, 2, 3)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('a', 'b', 'c'))
        self.assertEqual(v.b, 2)

    def testGet3Rows(self):
        q = "select 3 union select 1 union select 2 order by 1"
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        for v in r:
            self.assertEqual(v._fields, ('alias3',))

    def testDictresultNames(self):
        # unquoted aliases are folded to lower case by the server
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [{'mixedcasealias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        result = [{'MixedCaseAlias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testNamedresultNames(self):
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [('MixedCase',)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('mixedcasealias',))
        self.assertEqual(v.mixedcasealias, 'MixedCase')
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('MixedCaseAlias',))
        self.assertEqual(v.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = 100
        num_rows = 100
        q = "select " + ','.join(map(str, range(num_cols)))
        q = ' union all '.join((q,) * num_rows)
        r = self.c.query(q).getresult()
        result = [tuple(range(num_cols))] * num_rows
        self.assertEqual(r, result)

    def testListfields(self):
        q = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        r = self.c.query(q).listfields()
        self.assertIsInstance(r, tuple)
        result = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(r, result)

    def testFieldname(self):
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldname(2)
        self.assertEqual(r, 'x')
        r = self.c.query(q).fieldname(3)
        self.assertEqual(r, 'y')

    def testFieldnum(self):
        q = "select 1 as x"
        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        r = self.c.query(q).fieldnum('y')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3)

    def testNtuples(self):
        q = "select 1 where false"
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 6)

    def testQuery(self):
        query = self.c.query
        query("drop table if exists test_table")
        # "if exists" keeps the cleanup from adding a second error
        # when the table could not be created in the first place
        self.addCleanup(query, "drop table if exists test_table")
        q = "create table test_table (n integer) with oids"
        r = query(q)
        self.assertIsNone(r)
        # inserting a single row returns the oid of that row as an int
        q = "insert into test_table values (1)"
        r = query(q)
        self.assertIsInstance(r, int)
        q = "insert into test_table select 2"
        r = query(q)
        self.assertIsInstance(r, int)
        oid = r
        q = "select oid from test_table where n=2"
        r = query(q).getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertIsInstance(r, int)
        self.assertEqual(r, oid)
        # statements affecting several rows return the row count as a str
        q = "insert into test_table select 3 union select 4 union select 5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '3')
        q = "update test_table set n=4 where n<5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '4')
        q = "delete from test_table"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '5')
598
599
class TestUnicodeQueries(unittest.TestCase):
    """Test unicode strings as queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')

    def tearDown(self):
        self.c.close()

    def testGetresultAscii(self):
        result = u'Hello, world!'
        q = u"select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultAscii(self):
        result = u'Hello, world!'
        q = u"select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    # The following tests pass the query first as a unicode string and then
    # encoded as bytes in the client encoding, and expect the same native
    # str result both times (i.e. encoded bytes when str is not unicode).

    def testGetresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('utf8')
        # pass the query as unicode
        try:
            v = self.c.query(q).getresult()[0][0]
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        # pass the query as bytes
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('utf8')
        try:
            v = self.c.query(q).dictresult()[0]['greeting']
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin1(self):
        # previously also named testDictresultLatin1, which made the
        # definition below shadow this test so that it never ran
        try:
            self.c.query('set client_encoding=latin1')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        # uses characters specific to Latin-9 (œ, š, ž, €)
        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        # uses characters specific to Latin-9 (œ, š, ž, €)
        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
        q = u"select '%s' as menu" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
759
760
761class TestParamQueries(unittest.TestCase):
762    """Test queries with parameters via a basic pg connection."""
763
764    def setUp(self):
765        self.c = connect()
766        self.c.query('set client_encoding=utf8')
767
768    def tearDown(self):
769        self.c.close()
770
771    def testQueryWithNoneParam(self):
772        self.assertRaises(TypeError, self.c.query, "select $1", None)
773        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
774        self.assertEqual(self.c.query("select $1::integer", (None,)
775            ).getresult(), [(None,)])
776        self.assertEqual(self.c.query("select $1::text", [None]
777            ).getresult(), [(None,)])
778        self.assertEqual(self.c.query("select $1::text", [[None]]
779            ).getresult(), [(None,)])
780
781    def testQueryWithBoolParams(self, bool_enabled=None):
782        query = self.c.query
783        if bool_enabled is not None:
784            bool_enabled_default = pg.get_bool()
785            pg.set_bool(bool_enabled)
786        try:
787            bool_on = bool_enabled or bool_enabled is None
788            v_false, v_true = (False, True) if bool_on else 'ft'
789            r_false, r_true = [(v_false,)], [(v_true,)]
790            self.assertEqual(query("select false").getresult(), r_false)
791            self.assertEqual(query("select true").getresult(), r_true)
792            q = "select $1::bool"
793            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
794            self.assertEqual(query(q, ('f',)).getresult(), r_false)
795            self.assertEqual(query(q, ('t',)).getresult(), r_true)
796            self.assertEqual(query(q, ('false',)).getresult(), r_false)
797            self.assertEqual(query(q, ('true',)).getresult(), r_true)
798            self.assertEqual(query(q, ('n',)).getresult(), r_false)
799            self.assertEqual(query(q, ('y',)).getresult(), r_true)
800            self.assertEqual(query(q, (0,)).getresult(), r_false)
801            self.assertEqual(query(q, (1,)).getresult(), r_true)
802            self.assertEqual(query(q, (False,)).getresult(), r_false)
803            self.assertEqual(query(q, (True,)).getresult(), r_true)
804        finally:
805            if bool_enabled is not None:
806                pg.set_bool(bool_enabled_default)
807
808    def testQueryWithBoolParamsNotDefault(self):
809        self.testQueryWithBoolParams(bool_enabled=not pg.get_bool())
810
811    def testQueryWithIntParams(self):
812        query = self.c.query
813        self.assertEqual(query("select 1+1").getresult(), [(2,)])
814        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
815        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
816        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
817        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
818        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
819            [(Decimal('2'),)])
820        self.assertEqual(query("select 1, $1::integer", (2,)
821            ).getresult(), [(1, 2)])
822        self.assertEqual(query("select 1 union select $1::integer", (2,)
823            ).getresult(), [(1,), (2,)])
824        self.assertEqual(query("select $1::integer+$2", (1, 2)
825            ).getresult(), [(3,)])
826        self.assertEqual(query("select $1::integer+$2", [1, 2]
827            ).getresult(), [(3,)])
828        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
829            ).getresult(), [(15,)])
830
831    def testQueryWithStrParams(self):
832        query = self.c.query
833        self.assertEqual(query("select $1||', world!'", ('Hello',)
834            ).getresult(), [('Hello, world!',)])
835        self.assertEqual(query("select $1||', world!'", ['Hello']
836            ).getresult(), [('Hello, world!',)])
837        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
838            ).getresult(), [('Hello, world!',)])
839        self.assertEqual(query("select $1::text", ('Hello, world!',)
840            ).getresult(), [('Hello, world!',)])
841        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
842            ).getresult(), [('Hello', 'world')])
843        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
844            ).getresult(), [('Hello', 'world')])
845        self.assertEqual(query("select $1::text union select $2::text",
846            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
847        try:
848            query("select 'wörld'")
849        except (pg.DataError, pg.NotSupportedError):
850            self.skipTest('database does not support utf8')
851        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
852            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
853
854    def testQueryWithUnicodeParams(self):
855        query = self.c.query
856        try:
857            query('set client_encoding=utf8')
858            query("select 'wörld'").getresult()[0][0] == 'wörld'
859        except (pg.DataError, pg.NotSupportedError):
860            self.skipTest("database does not support utf8")
861        self.assertEqual(query("select $1||', '||$2||'!'",
862            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
863
864    def testQueryWithUnicodeParamsLatin1(self):
865        query = self.c.query
866        try:
867            query('set client_encoding=latin1')
868            query("select 'wörld'").getresult()[0][0] == 'wörld'
869        except (pg.DataError, pg.NotSupportedError):
870            self.skipTest("database does not support latin1")
871        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
872        if unicode_strings:
873            self.assertEqual(r, [('Hello, wörld!',)])
874        else:
875            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
876        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
877            ('Hello', u'ЌОр'))
878        query('set client_encoding=iso_8859_1')
879        r = query("select $1||', '||$2||'!'",
880            ('Hello', u'wörld')).getresult()
881        if unicode_strings:
882            self.assertEqual(r, [('Hello, wörld!',)])
883        else:
884            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
885        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
886            ('Hello', u'ЌОр'))
887        query('set client_encoding=sql_ascii')
888        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
889            ('Hello', u'wörld'))
890
891    def testQueryWithUnicodeParamsCyrillic(self):
892        query = self.c.query
893        try:
894            query('set client_encoding=iso_8859_5')
895            query("select 'ЌОр'").getresult()[0][0] == 'ЌОр'
896        except (pg.DataError, pg.NotSupportedError):
897            self.skipTest("database does not support cyrillic")
898        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
899            ('Hello', u'wörld'))
900        r = query("select $1||', '||$2||'!'",
901            ('Hello', u'ЌОр')).getresult()
902        if unicode_strings:
903            self.assertEqual(r, [('Hello, ЌОр!',)])
904        else:
905            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
906        query('set client_encoding=sql_ascii')
907        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
908            ('Hello', u'ЌОр!'))
909
910    def testQueryWithMixedParams(self):
911        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
912            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
913        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
914            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
915
916    def testQueryWithDuplicateParams(self):
917        self.assertRaises(pg.ProgrammingError,
918            self.c.query, "select $1+$1", (1,))
919        self.assertRaises(pg.ProgrammingError,
920            self.c.query, "select $1+$1", (1, 2))
921
922    def testQueryWithZeroParams(self):
923        self.assertEqual(self.c.query("select 1+1", []
924            ).getresult(), [(2,)])
925
926    def testQueryWithGarbage(self):
927        garbage = r"'\{}+()-#[]oo324"
928        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
929            ).dictresult(), [{'garbage': garbage}])
930
931
class TestQueryResultTypes(unittest.TestCase):
    """Test proper result types via a basic pg connection."""

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set timezone='UTC'")

    def tearDown(self):
        self.c.close()

    def assert_proper_cast(self, value, pgtype, pytype):
        """Check the Python type returned for a value and a 1-element array.

        ``value`` is passed as query parameter and cast to ``pgtype``; the
        result must be an instance of ``pytype``.  The value is then
        wrapped into an array literal and cast to ``pgtype[]``; the result
        must be a one-element list of ``pytype`` (or, for date, time and
        interval types, the unchanged array literal string).

        Skips the test if a json type is unknown to the database; any other
        ``pg.ProgrammingError`` is re-raised.
        """
        q = 'select $1::%s' % (pgtype,)
        try:
            r = self.c.query(q, (value,)).getresult()[0][0]
        except pg.ProgrammingError:
            if pgtype in ('json', 'jsonb'):
                self.skipTest('database does not support json')
            # all other types must be known; without re-raising here the
            # real error would be masked by an unbound local ``r`` below
            raise
        self.assertIsInstance(r, pytype)
        if isinstance(value, str):
            # quote strings that are empty or would break the array literal
            if not value or ' ' in value or '{' in value:
                value = '"%s"' % value
        value = '{%s}' % value
        r = self.c.query(q + '[]', (value,)).getresult()[0][0]
        if pgtype.startswith(('date', 'time', 'interval')):
            # arrays of these are casted by the DB wrapper only
            self.assertEqual(r, value)
        else:
            self.assertIsInstance(r, list)
            self.assertEqual(len(r), 1)
            self.assertIsInstance(r[0], pytype)

    def testInt(self):
        self.assert_proper_cast(0, 'int', int)
        self.assert_proper_cast(0, 'smallint', int)
        self.assert_proper_cast(0, 'oid', int)
        self.assert_proper_cast(0, 'cid', int)
        self.assert_proper_cast(0, 'xid', int)

    def testLong(self):
        self.assert_proper_cast(0, 'bigint', long)

    def testFloat(self):
        self.assert_proper_cast(0, 'float', float)
        self.assert_proper_cast(0, 'real', float)
        # PostgreSQL has no type named plain 'double', only this one
        self.assert_proper_cast(0, 'double precision', float)
        self.assert_proper_cast('infinity', 'float', float)

    def testNumeric(self):
        # must not also be called testFloat, otherwise this definition
        # would silently replace the float checks above
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal(0), 'numeric', decimal)
        self.assert_proper_cast(decimal(0), 'decimal', decimal)

    def testMoney(self):
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal('0'), 'money', decimal)

    def testBool(self):
        bool_type = bool if pg.get_bool() else str
        self.assert_proper_cast('f', 'bool', bool_type)

    def testDate(self):
        self.assert_proper_cast('1956-01-31', 'date', str)
        self.assert_proper_cast('10:20:30', 'interval', str)
        self.assert_proper_cast('08:42:15', 'time', str)
        self.assert_proper_cast('08:42:15+00', 'timetz', str)
        self.assert_proper_cast('1956-01-31 08:42:15', 'timestamp', str)
        self.assert_proper_cast('1956-01-31 08:42:15+00', 'timestamptz', str)

    def testText(self):
        self.assert_proper_cast('', 'text', str)
        self.assert_proper_cast('', 'char', str)
        self.assert_proper_cast('', 'bpchar', str)
        self.assert_proper_cast('', 'varchar', str)

    def testBytea(self):
        self.assert_proper_cast('', 'bytea', bytes)

    def testJson(self):
        self.assert_proper_cast('{}', 'json', dict)
1014
1015
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    # set by setUpClass; setUp asserts it so that tests never run against
    # a table that has not been created
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Create the test table and set ``has_encoding``.

        ``has_encoding`` is False when the database rejects the non-ascii
        probe or does not count a non-ascii character as one character.
        """
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test ("
                "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
                "d numeric, f4 real, f8 double precision, m money,"
                "c char(1), v4 varchar(4), c4 char(4), t text)")
            # Check whether the test database uses SQL_ASCII - this means
            # that it does not consider encoding when calculating lengths.
            c.query("set client_encoding=utf8")
            try:
                c.query("select 'À'")
            except (pg.DataError, pg.NotSupportedError):
                cls.has_encoding = False
            else:
                cls.has_encoding = c.query(
                    "select length('À') - length('a')").getresult()[0][0] == 0
        finally:
            # do not leak the connection if one of the queries fails
            c.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the test table again."""
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        """Empty the test table and close the connection."""
        try:
            self.c.query("truncate table test")
        finally:
            self.c.close()

    # sample rows, one value per column of the test table
    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def db_len(cls, s, encoding):
        """Return the length of s as counted by the database.

        Counts characters when the database is encoding-aware, otherwise
        the number of bytes of s in the given encoding.
        """
        if cls.has_encoding:
            s = s if isinstance(s, unicode) else s.decode(encoding)
        else:
            s = s.encode(encoding) if isinstance(s, unicode) else s
        return len(s)

    def get_back(self, encoding='utf-8'):
        """Read back all rows of the test table, ordered by the first column.

        Checks the Python type of every non-null value and normalizes some
        values so that rows compare equal to the inserted data: numeric
        becomes float, money becomes a str of its float value and char(4)
        is right-stripped.  Char lengths are checked with db_len().
        """
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], bool)
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision
                self.assertIsInstance(row[8], float)
                row[8] = float(row[8])
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(self.db_len(row[10], encoding), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(self.db_len(row[11], encoding), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(self.db_len(row[12], encoding), 4)
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            row = tuple(row)
            data.append(row)
        return data

    def testInserttable1Row(self):
        data = self.data[2:3]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        data = self.data
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMultipleRows(self):
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        data = [(None,) * 14] * 100
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableByteValues(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.DataError:
            self.skipTest("database does not support utf8")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"kÀse сыр pont-l'évêque")
        row_bytes = tuple(s.encode('utf-8')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_bytes] * 2
        self.c.inserttable('test', data)
        if unicode_strings:
            data = [row_unicode] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeUtf8(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.DataError:
            self.skipTest("database does not support utf8")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"kÀse сыр pont-l'évêque")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            row_bytes = tuple(s.encode('utf-8')
                if isinstance(s, unicode) else s for s in row_unicode)
            data = [row_bytes] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeLatin1(self):
        try:
            self.c.query("set client_encoding=latin1")
            self.c.query("select 'Â¥'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode € sign with latin1 encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
        row_unicode = tuple(s.replace(u'€', u'Â¥')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            row_bytes = tuple(s.encode('latin1')
                if isinstance(s, unicode) else s for s in row_unicode)
            data = [row_bytes] * 2
        self.assertEqual(self.get_back('latin1'), data)

    def testInserttableUnicodeLatin9(self):
        try:
            self.c.query("set client_encoding=latin9")
            self.c.query("select '€'")
        except (pg.DataError, pg.NotSupportedError):
            # skipTest raises, so no explicit return is needed here
            self.skipTest("database does not support latin9")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            row_bytes = tuple(s.encode('latin9')
                if isinstance(s, unicode) else s for s in row_unicode)
            data = [row_bytes] * 2
        self.assertEqual(self.get_back('latin9'), data)

    def testInserttableNoEncoding(self):
        self.c.query("set client_encoding=sql_ascii")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode non-ascii unicode without a specific encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
1253
1254
class TestDirectSocketAccess(unittest.TestCase):
    """Test the COPY protocol via putline, getline and endcopy."""

    # flipped by setUpClass once the test table exists
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        conn = connect()
        for sql in ("drop table if exists test cascade",
                    "create table test (i int, v varchar(16))"):
            conn.query(sql)
        conn.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        conn = connect()
        conn.query("drop table test cascade")
        conn.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    def testPutline(self):
        """Feed rows into 'copy from stdin' line by line."""
        conn = self.c
        fruits = "apple pear plum cherry banana".split()
        data = list(enumerate(fruits))
        conn.query("copy test from stdin")
        try:
            for row in data:
                conn.putline("%d\t%s\n" % row)
            conn.putline("\\.\n")
        finally:
            conn.endcopy()
        self.assertEqual(conn.query("select * from test").getresult(), data)

    def testPutlineBytesAndUnicode(self):
        """Feed both encoded bytes and native strings to putline."""
        conn = self.c
        try:
            conn.query("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        conn.query("copy test from stdin")
        try:
            conn.putline(u"47\tkÀse\n".encode('utf8'))
            conn.putline("35\twÃŒrstel\n")
            conn.putline(b"\\.\n")
        finally:
            conn.endcopy()
        rows = conn.query("select * from test").getresult()
        self.assertEqual(rows, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        """Read the rows of 'copy to stdout' back line by line.

        After the data lines, the terminator line is returned,
        followed by None.
        """
        conn = self.c
        data = list(enumerate("apple banana pear plum strawberry".split()))
        conn.inserttable('test', data)
        conn.query("copy test to stdout")
        try:
            for row in data:
                self.assertEqual(conn.getline(), '%d\t%s' % row)
            self.assertEqual(conn.getline(), '\\.')
            self.assertIsNone(conn.getline())
        finally:
            try:
                conn.endcopy()
            except IOError:
                pass

    def testGetlineBytesAndUnicode(self):
        """getline returns native strings, whatever was inserted."""
        conn = self.c
        try:
            conn.query("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        rows = [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')]
        conn.inserttable('test', rows)
        conn.query("copy test to stdout")
        try:
            for expected in ('54\tkÀse', '73\twÃŒrstel'):
                line = conn.getline()
                self.assertIsInstance(line, str)
                self.assertEqual(line, expected)
            self.assertEqual(conn.getline(), '\\.')
            self.assertIsNone(conn.getline())
        finally:
            try:
                conn.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        """Wrong argument counts must raise TypeError."""
        conn = self.c
        self.assertRaises(TypeError, conn.putline)
        for method in (conn.getline, conn.endcopy):
            self.assertRaises(TypeError, method, 'invalid')
1365
1366
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # run the cleanups while the connection is still open, because
        # they may need it (see testNoticeReceiver)
        self.doCleanups()
        self.c.close()

    def _assert_notification(self, r, payload):
        """Check a getnotify() result for the 'test_notify' channel.

        It must be a 3-tuple of the channel name, an int and the
        given payload string.
        """
        self.assertIsInstance(r, tuple)
        self.assertEqual(len(r), 3)
        self.assertIsInstance(r[0], str)
        self.assertIsInstance(r[1], int)
        self.assertIsInstance(r[2], str)
        self.assertEqual(r[0], 'test_notify')
        self.assertEqual(r[2], payload)

    def testGetNotify(self):
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            self._assert_notification(getnotify(), '')
            self.assertIsNone(getnotify())
            query("notify test_notify, 'test_payload'")
            self._assert_notification(getnotify(), 'test_payload')
            self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 'invalid')
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))
        self.assertIsNone(self.c.set_notice_receiver(None))

    def testSetAndGetNoticeReceiver(self):
        def r(notice):
            return None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)
        self.assertIsNone(self.c.set_notice_receiver(None))
        self.assertIsNone(self.c.get_notice_receiver())

    def testNoticeReceiver(self):
        # 'if exists' keeps the cleanup from raising a second error
        # in case creating the function below fails
        self.addCleanup(self.c.query,
            'drop function if exists bilbo_notice();')
        self.c.query('''create function bilbo_notice() returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        received = {}

        def notice_receiver(notice):
            # collect all public attributes of the notice object
            for attr in dir(notice):
                if attr.startswith('__'):
                    continue
                value = getattr(notice, attr)
                if isinstance(value, str):
                    # normalize the German server message severity
                    value = value.replace('WARNUNG', 'WARNING')
                received[attr] = value

        self.c.set_notice_receiver(notice_receiver)
        self.c.query('select bilbo_notice()')
        self.assertEqual(received, dict(
            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
            severity='WARNING', primary='Bilbo was here!',
            detail=None, hint=None))
1437                if isinstance(value, str):
1438                    value = value.replace('WARNUNG', 'WARNING')
1439                received[attr] = value
1440
1441        self.c.set_notice_receiver(notice_receiver)
1442        self.c.query('select bilbo_notice()')
1443        self.assertEqual(received, dict(
1444            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
1445            severity='WARNING', primary='Bilbo was here!',
1446            detail=None, hint=None))
1447
1448
1449class TestConfigFunctions(unittest.TestCase):
1450    """Test the functions for changing default settings.
1451
1452    To test the effect of most of these functions, we need a database
1453    connection.  That's why they are covered in this test module.
1454    """
1455
1456    def setUp(self):
1457        self.c = connect()
1458        self.c.query("set client_encoding=utf8")
1459        self.c.query('set bytea_output=hex')
1460        self.c.query("set lc_monetary='C'")
1461
1462    def tearDown(self):
1463        self.c.close()
1464
1465    def testGetDecimalPoint(self):
1466        point = pg.get_decimal_point()
1467        # error if a parameter is passed
1468        self.assertRaises(TypeError, pg.get_decimal_point, point)
1469        self.assertIsInstance(point, str)
1470        self.assertEqual(point, '.')  # the default setting
1471        pg.set_decimal_point(',')
1472        try:
1473            r = pg.get_decimal_point()
1474        finally:
1475            pg.set_decimal_point(point)
1476        self.assertIsInstance(r, str)
1477        self.assertEqual(r, ',')
1478        pg.set_decimal_point("'")
1479        try:
1480            r = pg.get_decimal_point()
1481        finally:
1482            pg.set_decimal_point(point)
1483        self.assertIsInstance(r, str)
1484        self.assertEqual(r, "'")
1485        pg.set_decimal_point('')
1486        try:
1487            r = pg.get_decimal_point()
1488        finally:
1489            pg.set_decimal_point(point)
1490        self.assertIsNone(r)
1491        pg.set_decimal_point(None)
1492        try:
1493            r = pg.get_decimal_point()
1494        finally:
1495            pg.set_decimal_point(point)
1496        self.assertIsNone(r)
1497
1498    def testSetDecimalPoint(self):
1499        d = pg.Decimal
1500        point = pg.get_decimal_point()
1501        self.assertRaises(TypeError, pg.set_decimal_point)
1502        # error if decimal point is not a string
1503        self.assertRaises(TypeError, pg.set_decimal_point, 0)
1504        # error if more than one decimal point passed
1505        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
1506        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
1507        # error if decimal point is not a punctuation character
1508        self.assertRaises(TypeError, pg.set_decimal_point, '0')
1509        query = self.c.query
1510        # check that money values are interpreted as decimal values
1511        # only if decimal_point is set, and that the result is correct
1512        # only if it is set suitable for the current lc_monetary setting
1513        select_money = "select '34.25'::money"
1514        proper_money = d('34.25')
1515        bad_money = d('3425')
1516        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
1517        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
1518        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
1519        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
1520            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
1521        # first try with English localization (using the point)
1522        for lc in en_locales:
1523            try:
1524                query("set lc_monetary='%s'" % lc)
1525            except pg.DataError:
1526                pass
1527            else:
1528                break
1529        else:
1530            self.skipTest("cannot set English money locale")
1531        try:
1532            r = query(select_money)
1533        except pg.DataError:
1534            # this can happen if the currency signs cannot be
1535            # converted using the encoding of the test database
1536            self.skipTest("database does not support English money")
1537        pg.set_decimal_point(None)
1538        try:
1539            r = r.getresult()[0][0]
1540        finally:
1541            pg.set_decimal_point(point)
1542        self.assertIsInstance(r, str)
1543        self.assertIn(r, en_money)
1544        r = query(select_money)
1545        pg.set_decimal_point('')
1546        try:
1547            r = r.getresult()[0][0]
1548        finally:
1549            pg.set_decimal_point(point)
1550        self.assertIsInstance(r, str)
1551        self.assertIn(r, en_money)
1552        r = query(select_money)
1553        pg.set_decimal_point('.')
1554        try:
1555            r = r.getresult()[0][0]
1556        finally:
1557            pg.set_decimal_point(point)
1558        self.assertIsInstance(r, d)
1559        self.assertEqual(r, proper_money)
1560        r = query(select_money)
1561        pg.set_decimal_point(',')
1562        try:
1563            r = r.getresult()[0][0]
1564        finally:
1565            pg.set_decimal_point(point)
1566        self.assertIsInstance(r, d)
1567        self.assertEqual(r, bad_money)
1568        r = query(select_money)
1569        pg.set_decimal_point("'")
1570        try:
1571            r = r.getresult()[0][0]
1572        finally:
1573            pg.set_decimal_point(point)
1574        self.assertIsInstance(r, d)
1575        self.assertEqual(r, bad_money)
1576        # then try with German localization (using the comma)
1577        for lc in de_locales:
1578            try:
1579                query("set lc_monetary='%s'" % lc)
1580            except pg.DataError:
1581                pass
1582            else:
1583                break
1584        else:
1585            self.skipTest("cannot set German money locale")
1586        select_money = select_money.replace('.', ',')
1587        try:
1588            r = query(select_money)
1589        except pg.DataError:
1590            self.skipTest("database does not support English money")
1591        pg.set_decimal_point(None)
1592        try:
1593            r = r.getresult()[0][0]
1594        finally:
1595            pg.set_decimal_point(point)
1596        self.assertIsInstance(r, str)
1597        self.assertIn(r, de_money)
1598        r = query(select_money)
1599        pg.set_decimal_point('')
1600        try:
1601            r = r.getresult()[0][0]
1602        finally:
1603            pg.set_decimal_point(point)
1604        self.assertIsInstance(r, str)
1605        self.assertIn(r, de_money)
1606        r = query(select_money)
1607        pg.set_decimal_point(',')
1608        try:
1609            r = r.getresult()[0][0]
1610        finally:
1611            pg.set_decimal_point(point)
1612        self.assertIsInstance(r, d)
1613        self.assertEqual(r, proper_money)
1614        r = query(select_money)
1615        pg.set_decimal_point('.')
1616        try:
1617            r = r.getresult()[0][0]
1618        finally:
1619            pg.set_decimal_point(point)
1620        self.assertEqual(r, bad_money)
1621        r = query(select_money)
1622        pg.set_decimal_point("'")
1623        try:
1624            r = r.getresult()[0][0]
1625        finally:
1626            pg.set_decimal_point(point)
1627        self.assertEqual(r, bad_money)
1628
1629    def testGetDecimal(self):
1630        decimal_class = pg.get_decimal()
1631        # error if a parameter is passed
1632        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
1633        self.assertIs(decimal_class, pg.Decimal)  # the default setting
1634        pg.set_decimal(int)
1635        try:
1636            r = pg.get_decimal()
1637        finally:
1638            pg.set_decimal(decimal_class)
1639        self.assertIs(r, int)
1640        r = pg.get_decimal()
1641        self.assertIs(r, decimal_class)
1642
1643    def testSetDecimal(self):
1644        decimal_class = pg.get_decimal()
1645        # error if no parameter is passed
1646        self.assertRaises(TypeError, pg.set_decimal)
1647        query = self.c.query
1648        try:
1649            r = query("select 3425::numeric")
1650        except pg.DatabaseError:
1651            self.skipTest('database does not support numeric')
1652        r = r.getresult()[0][0]
1653        self.assertIsInstance(r, decimal_class)
1654        self.assertEqual(r, decimal_class('3425'))
1655        r = query("select 3425::numeric")
1656        pg.set_decimal(int)
1657        try:
1658            r = r.getresult()[0][0]
1659        finally:
1660            pg.set_decimal(decimal_class)
1661        self.assertNotIsInstance(r, decimal_class)
1662        self.assertIsInstance(r, int)
1663        self.assertEqual(r, int(3425))
1664
1665    def testGetBool(self):
1666        use_bool = pg.get_bool()
1667        # error if a parameter is passed
1668        self.assertRaises(TypeError, pg.get_bool, use_bool)
1669        self.assertIsInstance(use_bool, bool)
1670        self.assertIs(use_bool, True)  # the default setting
1671        pg.set_bool(False)
1672        try:
1673            r = pg.get_bool()
1674        finally:
1675            pg.set_bool(use_bool)
1676        self.assertIsInstance(r, bool)
1677        self.assertIs(r, False)
1678        pg.set_bool(True)
1679        try:
1680            r = pg.get_bool()
1681        finally:
1682            pg.set_bool(use_bool)
1683        self.assertIsInstance(r, bool)
1684        self.assertIs(r, True)
1685        pg.set_bool(0)
1686        try:
1687            r = pg.get_bool()
1688        finally:
1689            pg.set_bool(use_bool)
1690        self.assertIsInstance(r, bool)
1691        self.assertIs(r, False)
1692        pg.set_bool(1)
1693        try:
1694            r = pg.get_bool()
1695        finally:
1696            pg.set_bool(use_bool)
1697        self.assertIsInstance(r, bool)
1698        self.assertIs(r, True)
1699
1700    def testSetBool(self):
1701        use_bool = pg.get_bool()
1702        # error if no parameter is passed
1703        self.assertRaises(TypeError, pg.set_bool)
1704        query = self.c.query
1705        try:
1706            r = query("select true::bool")
1707        except pg.ProgrammingError:
1708            self.skipTest('database does not support bool')
1709        r = r.getresult()[0][0]
1710        self.assertIsInstance(r, bool)
1711        self.assertEqual(r, True)
1712        r = query("select true::bool")
1713        pg.set_bool(False)
1714        try:
1715            r = r.getresult()[0][0]
1716        finally:
1717            pg.set_bool(use_bool)
1718        self.assertIsInstance(r, str)
1719        self.assertIs(r, 't')
1720        r = query("select true::bool")
1721        pg.set_bool(True)
1722        try:
1723            r = r.getresult()[0][0]
1724        finally:
1725            pg.set_bool(use_bool)
1726        self.assertIsInstance(r, bool)
1727        self.assertIs(r, True)
1728
1729    def testGetByteEscaped(self):
1730        bytea_escaped = pg.get_bytea_escaped()
1731        # error if a parameter is passed
1732        self.assertRaises(TypeError, pg.get_bytea_escaped, bytea_escaped)
1733        self.assertIsInstance(bytea_escaped, bool)
1734        self.assertIs(bytea_escaped, False)  # the default setting
1735        pg.set_bytea_escaped(True)
1736        try:
1737            r = pg.get_bytea_escaped()
1738        finally:
1739            pg.set_bytea_escaped(bytea_escaped)
1740        self.assertIsInstance(r, bool)
1741        self.assertIs(r, True)
1742        pg.set_bytea_escaped(False)
1743        try:
1744            r = pg.get_bytea_escaped()
1745        finally:
1746            pg.set_bytea_escaped(bytea_escaped)
1747        self.assertIsInstance(r, bool)
1748        self.assertIs(r, False)
1749        pg.set_bytea_escaped(1)
1750        try:
1751            r = pg.get_bytea_escaped()
1752        finally:
1753            pg.set_bytea_escaped(bytea_escaped)
1754        self.assertIsInstance(r, bool)
1755        self.assertIs(r, True)
1756        pg.set_bytea_escaped(0)
1757        try:
1758            r = pg.get_bytea_escaped()
1759        finally:
1760            pg.set_bytea_escaped(bytea_escaped)
1761        self.assertIsInstance(r, bool)
1762        self.assertIs(r, False)
1763
1764    def testSetByteaEscaped(self):
1765        bytea_escaped = pg.get_bytea_escaped()
1766        # error if no parameter is passed
1767        self.assertRaises(TypeError, pg.set_bytea_escaped)
1768        query = self.c.query
1769        try:
1770            r = query("select 'data'::bytea")
1771        except pg.ProgrammingError:
1772            self.skipTest('database does not support bytea')
1773        r = r.getresult()[0][0]
1774        self.assertIsInstance(r, bytes)
1775        self.assertEqual(r, b'data')
1776        r = query("select 'data'::bytea")
1777        pg.set_bytea_escaped(True)
1778        try:
1779            r = r.getresult()[0][0]
1780        finally:
1781            pg.set_bytea_escaped(bytea_escaped)
1782        self.assertIsInstance(r, str)
1783        self.assertEqual(r, '\\x64617461')
1784        r = query("select 'data'::bytea")
1785        pg.set_bytea_escaped(False)
1786        try:
1787            r = r.getresult()[0][0]
1788        finally:
1789            pg.set_bytea_escaped(bytea_escaped)
1790        self.assertIsInstance(r, bytes)
1791        self.assertEqual(r, b'data')
1792
1793    def testGetNamedresult(self):
1794        namedresult = pg.get_namedresult()
1795        # error if a parameter is passed
1796        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
1797        self.assertIs(namedresult, pg._namedresult)  # the default setting
1798
1799    def testSetNamedresult(self):
1800        namedresult = pg.get_namedresult()
1801        self.assertTrue(callable(namedresult))
1802
1803        query = self.c.query
1804
1805        r = query("select 1 as x, 2 as y").namedresult()[0]
1806        self.assertIsInstance(r, tuple)
1807        self.assertEqual(r, (1, 2))
1808        self.assertIsNot(type(r), tuple)
1809        self.assertEqual(r._fields, ('x', 'y'))
1810        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
1811        self.assertEqual(r.__class__.__name__, 'Row')
1812
1813        def listresult(q):
1814            return [list(row) for row in q.getresult()]
1815
1816        pg.set_namedresult(listresult)
1817        try:
1818            r = pg.get_namedresult()
1819            self.assertIs(r, listresult)
1820            r = query("select 1 as x, 2 as y").namedresult()[0]
1821            self.assertIsInstance(r, list)
1822            self.assertEqual(r, [1, 2])
1823            self.assertIsNot(type(r), tuple)
1824            self.assertFalse(hasattr(r, '_fields'))
1825            self.assertNotEqual(r.__class__.__name__, 'Row')
1826        finally:
1827            pg.set_namedresult(namedresult)
1828
1829        r = pg.get_namedresult()
1830        self.assertIs(r, namedresult)
1831
1832    def testSetRowFactorySize(self):
1833        try:
1834            from functools import lru_cache
1835        except ImportError:  # Python < 3.2
1836            lru_cache = None
1837        queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc']
1838        query = self.c.query
1839        for maxsize in (None, 0, 1, 2, 3, 10, 1024):
1840            pg.set_row_factory_size(maxsize)
1841            for i in range(3):
1842                for q in queries:
1843                    r = query(q).namedresult()[0]
1844                    if q.endswith('abc'):
1845                        self.assertEqual(r, (123,))
1846                        self.assertEqual(r._fields, ('abc',))
1847                    else:
1848                        self.assertEqual(r, (1, 2, 3))
1849                        self.assertEqual(r._fields, ('a', 'b', 'c'))
1850            if lru_cache:
1851                info = pg._row_factory.cache_info()
1852                self.assertEqual(info.maxsize, maxsize)
1853                self.assertEqual(info.hits + info.misses, 6)
1854                self.assertEqual(info.hits,
1855                    0 if maxsize is not None and maxsize < 2 else 4)
1856
1857
class TestStandaloneEscapeFunctions(unittest.TestCase):
    """Test pg escape functions.

    The libpq interface memorizes some parameters of the last opened
    connection that influence the result of these functions.  Therefore
    we need to open a connection with fixed parameters prior to testing
    in order to ensure that the tests always run under the same conditions.
    That's why these tests are included in this test module.
    """

    # flag set by setUpClass; every test checks it to make sure the
    # class-level setup really ran before it
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        db = connect()
        try:
            query = db.query
            query('set client_encoding=sql_ascii')
            query('set standard_conforming_strings=off')
            try:
                query('set bytea_output=escape')
            except pg.ProgrammingError:
                # the error is only tolerated for server versions before
                # 9.0 (presumably lacking this setting); re-raise otherwise
                if db.server_version >= 90000:
                    raise
        finally:
            # do not leak the connection if one of the settings failed
            db.close()
        cls.cls_set_up = True

    def testEscapeString(self):
        self.assertTrue(self.cls_set_up)
        f = pg.escape_string
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        r = f(u"das is' käse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, u"das is'' käse".encode('utf-8'))
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        # with standard_conforming_strings=off, backslashes are doubled too
        r = f(r"It's bad to have a \ inside.")
        # the result has the same type as the (native str) argument
        self.assertIsInstance(r, str)
        self.assertEqual(r, r"It''s bad to have a \\ inside.")

    def testEscapeBytea(self):
        self.assertTrue(self.cls_set_up)
        f = pg.escape_bytea
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # non-ASCII bytes are escaped as octal sequences
        r = f(u"das is' käse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(b'O\x00ps\xff!')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'O\\\\000ps\\\\377!')
1919
1920
if __name__ == '__main__':
    # run all test cases of this module when it is executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.