source: trunk/tests/test_classic_connection.py @ 928

Last change on this file since 928 was 928, checked in by cito, 21 months ago

Adapt tests for Postgres 10

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 69.3 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17import threading
18import time
19import os
20
21from collections import namedtuple
22from decimal import Decimal
23
24import pg  # the module under test
25
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# These tests should be run with various PostgreSQL versions and databases
# created with different encodings and locales.  Particularly, make sure the
# tests are running against databases created with both SQL_ASCII and UTF8.
#
# Default connection settings passed to pg.connect() by connect() below.
# They must be assigned *before* the LOCAL_PyGreSQL star import, so that a
# local settings module can override any of them.
dbname = 'unittest'
dbhost = None  # no explicit host is passed to pg.connect()
dbport = 5432

# Look for LOCAL_PyGreSQL relative to this package first, then as a
# top-level module.  The relative import fails with ImportError (or, on
# older Pythons, ValueError) when this module is not part of a package;
# a missing local settings module is simply ignored.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Python 2/3 compatibility aliases used by the tests.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# True on Python 3, where str is a text type.  On Python 2 str is bytes,
# and the unicode tests compare against encoded byte strings instead.
unicode_strings = str is not bytes

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
61
62
def connect():
    """Open a pg connection to the test database.

    The connection parameters are taken from the module-level settings
    dbname, dbhost and dbport.  The returned connection has
    client_min_messages set to warning, so that lower-severity server
    notices are not sent to the client during the tests.
    """
    conn = pg.connect(dbname, dbhost, dbport)
    conn.query("set client_min_messages=warning")
    return conn
68
69
class TestCanConnect(unittest.TestCase):
    """Check that a plain connection to PostgreSQL can be opened and closed."""

    def testCanConnect(self):
        # First the connection must open at all ...
        try:
            conn = connect()
        except pg.Error as err:
            self.fail('Cannot connect to database %s:\n%s' % (dbname, err))
        else:
            # ... and then it must close cleanly again.
            try:
                conn.close()
            except pg.Error:
                self.fail('Cannot close the database connection')
82
83
class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods.

    Every test gets a fresh connection from connect() in setUp(); the
    connection is closed again in tearDown().
    """

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        # Some tests close the connection themselves; closing an already
        # closed connection raises pg.InternalError (see testMethodClose),
        # which is harmless here.
        try:
            self.connection.close()
        except pg.InternalError:
            pass

    def is_method(self, attribute):
        """Check if given attribute on the connection is a method."""
        if do_not_ask_for_host and attribute == 'host':
            # Do not touch the host attribute where asking libpq for the
            # host is known to crash; treat it as a plain attribute.
            return False
        return callable(getattr(self.connection, attribute))

    def testClassName(self):
        self.assertEqual(self.connection.__class__.__name__, 'Connection')

    def testModuleName(self):
        self.assertEqual(self.connection.__class__.__module__, 'pg')

    def testStr(self):
        r = str(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testRepr(self):
        r = repr(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testAllConnectAttributes(self):
        # The public non-callable attributes, in dir() (sorted) order.
        attributes = '''db error host options port
            protocol_version server_version status user'''.split()
        connection_attributes = [a for a in dir(self.connection)
            if not a.startswith('__') and not self.is_method(a)]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        # The public callable attributes, in dir() (sorted) order.
        methods = '''cancel close date_format endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_cast_hook get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter putline query reset
            set_cast_hook set_notice_receiver source transaction'''.split()
        connection_methods = [a for a in dir(self.connection)
            if not a.startswith('__') and self.is_method(a)]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        # A fresh connection should report no error; messages mentioning
        # krb5_ are tolerated.
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        # Without an explicit host, or with a socket directory (a path
        # starting with '/'), the connection is expected to report
        # 'localhost'.
        if dbhost and not dbhost.startswith('/'):
            host = dbhost
        else:
            host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        # NOTE(review): the upper bound has to be raised whenever the
        # tests are run against a newer PostgreSQL major release.
        self.assertTrue(90000 <= server_version < 110000)

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testMethodEndcopy(self):
        # Outside of a COPY operation endcopy() may raise IOError; the test
        # only checks that the method can be called.
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # Replace the closed connection with a fresh one before tearDown().
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query
        # check that client encoding gets reset
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            # Runs in a worker thread; the cancelled query is expected to
            # raise a DatabaseError whose message is collected here.
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.DatabaseError as error:
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        # NOTE(review): this only waits until the worker thread is alive,
        # not until the server is actually executing the query.
        while True:
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)

    def testMethodTransaction(self):
        transaction = self.connection.transaction
        self.assertRaises(TypeError, transaction, None)
        self.assertEqual(transaction(), pg.TRANS_IDLE)
        self.connection.query('begin')
        self.assertEqual(transaction(), pg.TRANS_INTRANS)
        self.connection.query('rollback')
        self.assertEqual(transaction(), pg.TRANS_IDLE)

    def testMethodParameter(self):
        parameter = self.connection.parameter
        query = self.connection.query
        self.assertRaises(TypeError, parameter)
        r = parameter('this server setting does not exist')
        self.assertIsNone(r)
        # parameter() must report exactly what SHOW reports.  The values are
        # compared verbatim: server_version may carry a mixed-case suffix
        # (e.g. a packager's build tag), so upper-casing only the SHOW side
        # of the comparison would make it fail spuriously.
        for name in ('server_version', 'server_encoding', 'client_encoding'):
            s = query('show %s' % name).getresult()[0][0]
            self.assertIsNotNone(s)
            r = parameter(name)
            self.assertEqual(r, s, name)
293
294
class TestSimpleQueries(unittest.TestCase):
    """Run simple queries through a basic pg connection and check results."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # Registered cleanups (e.g. dropping tables) still need the
        # connection, so run them now, before the connection is closed.
        self.doCleanups()
        self.c.close()

    def testClassName(self):
        res = self.c.query("select 1")
        self.assertEqual(res.__class__.__name__, 'Query')

    def testModuleName(self):
        res = self.c.query("select 1")
        self.assertEqual(res.__class__.__module__, 'pg')

    def testStr(self):
        sql = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        expected = ('a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)')
        self.assertEqual(str(self.c.query(sql)), expected)

    def testRepr(self):
        text = repr(self.c.query("select 1"))
        self.assertTrue(text.startswith('<pg.Query object'), text)

    def testSelect0(self):
        self.c.query("select 0")

    def testSelect0Semicolon(self):
        self.c.query("select 0;")

    def testSelectDotSemicolon(self):
        self.assertRaises(pg.DatabaseError, self.c.query, "select .;")

    def testGetresult(self):
        rows = self.c.query("select 0").getresult()
        self.assertIsInstance(rows, list)
        first = rows[0]
        self.assertIsInstance(first, tuple)
        self.assertIsInstance(first[0], int)
        self.assertEqual(rows, [(0,)])

    def testGetresultLong(self):
        expected = long(9876543210)
        self.assertIsInstance(expected, long)
        value = self.c.query("select 9876543210").getresult()[0][0]
        self.assertIsInstance(value, long)
        self.assertEqual(value, expected)

    def testGetresultDecimal(self):
        expected = Decimal(98765432109876543210)
        value = self.c.query("select 98765432109876543210").getresult()[0][0]
        self.assertIsInstance(value, Decimal)
        self.assertEqual(value, expected)

    def testGetresultString(self):
        expected = 'Hello, world!'
        value = self.c.query("select '%s'" % expected).getresult()[0][0]
        self.assertIsInstance(value, str)
        self.assertEqual(value, expected)

    def testDictresult(self):
        rows = self.c.query("select 0 as alias0").dictresult()
        self.assertIsInstance(rows, list)
        first = rows[0]
        self.assertIsInstance(first, dict)
        self.assertIsInstance(first['alias0'], int)
        self.assertEqual(rows, [{'alias0': 0}])

    def testDictresultLong(self):
        expected = long(9876543210)
        self.assertIsInstance(expected, long)
        sql = "select 9876543210 as longjohnsilver"
        value = self.c.query(sql).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(value, long)
        self.assertEqual(value, expected)

    def testDictresultDecimal(self):
        expected = Decimal(98765432109876543210)
        sql = "select 98765432109876543210 as longjohnsilver"
        value = self.c.query(sql).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(value, Decimal)
        self.assertEqual(value, expected)

    def testDictresultString(self):
        expected = 'Hello, world!'
        sql = "select '%s' as greeting" % expected
        value = self.c.query(sql).dictresult()[0]['greeting']
        self.assertIsInstance(value, str)
        self.assertEqual(value, expected)

    def testNamedresult(self):
        rows = self.c.query("select 0 as alias0").namedresult()
        self.assertEqual(rows, [(0,)])
        first = rows[0]
        self.assertEqual(first._fields, ('alias0',))
        self.assertEqual(first.alias0, 0)

    def testNamedresultWithGoodFieldnames(self):
        sql = 'select 1 as snake_case_alias, 2 as "CamelCaseAlias"'
        rows = self.c.query(sql).namedresult()
        self.assertEqual(rows, [(1, 2)])
        self.assertEqual(rows[0]._fields,
            ('snake_case_alias', 'CamelCaseAlias'))

    def testNamedresultWithBadFieldnames(self):
        # Ask namedtuple itself how it renames invalid field names, so the
        # expectation follows the running Python version.
        try:
            renamed = namedtuple('Bad', ['?'] * 6, rename=True)
        except TypeError:  # Python 2.6 or 3.0
            fields = tuple('column_%d' % n for n in range(6))
        else:
            fields = renamed._fields
        sql = ('select 3 as "0alias", 4 as _alias, 5 as "alias$", 6 as "alias?",'
            ' 7 as "kebap-case-alias", 8 as break, 9 as and_a_good_one')
        rows = self.c.query(sql).namedresult()
        self.assertEqual(rows, [tuple(range(3, 10))])
        names = rows[0]._fields
        self.assertEqual(names[:6], fields)
        self.assertEqual(names[6], 'and_a_good_one')

    def testGet3Cols(self):
        rows = self.c.query("select 1,2,3").getresult()
        self.assertEqual(rows, [(1, 2, 3)])

    def testGet3DictCols(self):
        rows = self.c.query("select 1 as a,2 as b,3 as c").dictresult()
        self.assertEqual(rows, [dict(a=1, b=2, c=3)])

    def testGet3NamedCols(self):
        rows = self.c.query("select 1 as a,2 as b,3 as c").namedresult()
        self.assertEqual(rows, [(1, 2, 3)])
        first = rows[0]
        self.assertEqual(first._fields, ('a', 'b', 'c'))
        self.assertEqual(first.b, 2)

    def testGet3Rows(self):
        sql = "select 3 union select 1 union select 2 order by 1"
        rows = self.c.query(sql).getresult()
        self.assertEqual(rows, [(1,), (2,), (3,)])

    def testGet3DictRows(self):
        sql = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        rows = self.c.query(sql).dictresult()
        self.assertEqual(rows, [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}])

    def testGet3NamedRows(self):
        sql = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        rows = self.c.query(sql).namedresult()
        self.assertEqual(rows, [(1,), (2,), (3,)])
        for row in rows:
            self.assertEqual(row._fields, ('alias3',))

    def testDictresultNames(self):
        # Unquoted aliases come back lower-cased, quoted ones keep their case.
        sql = "select 'MixedCase' as MixedCaseAlias"
        self.assertEqual(self.c.query(sql).dictresult(),
            [{'mixedcasealias': 'MixedCase'}])
        sql = "select 'MixedCase' as \"MixedCaseAlias\""
        self.assertEqual(self.c.query(sql).dictresult(),
            [{'MixedCaseAlias': 'MixedCase'}])

    def testNamedresultNames(self):
        expected = [('MixedCase',)]
        sql = "select 'MixedCase' as MixedCaseAlias"
        rows = self.c.query(sql).namedresult()
        self.assertEqual(rows, expected)
        first = rows[0]
        self.assertEqual(first._fields, ('mixedcasealias',))
        self.assertEqual(first.mixedcasealias, 'MixedCase')
        sql = "select 'MixedCase' as \"MixedCaseAlias\""
        rows = self.c.query(sql).namedresult()
        self.assertEqual(rows, expected)
        first = rows[0]
        self.assertEqual(first._fields, ('MixedCaseAlias',))
        self.assertEqual(first.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = num_rows = 100
        select = "select " + ','.join(str(n) for n in range(num_cols))
        sql = ' union all '.join([select] * num_rows)
        rows = self.c.query(sql).getresult()
        self.assertEqual(rows, [tuple(range(num_cols))] * num_rows)

    def testListfields(self):
        sql = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        names = self.c.query(sql).listfields()
        self.assertIsInstance(names, tuple)
        expected = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(names, expected)

    def testFieldname(self):
        sql = "select 0 as z, 0 as a, 0 as x, 0 as y"
        self.assertEqual(self.c.query(sql).fieldname(2), 'x')
        self.assertEqual(self.c.query(sql).fieldname(3), 'y')

    def testFieldnum(self):
        sql = "select 1 as x"
        self.assertRaises(ValueError, self.c.query(sql).fieldnum, 'y')
        num = self.c.query(sql).fieldnum('x')
        self.assertIsInstance(num, int)
        self.assertEqual(num, 0)
        sql = "select 0 as z, 0 as a, 0 as x, 0 as y"
        for name, expected in (('x', 2), ('y', 3)):
            num = self.c.query(sql).fieldnum(name)
            self.assertIsInstance(num, int)
            self.assertEqual(num, expected)

    def testNtuples(self):
        cases = (
            ("select 1 where false", 0),
            ("select 1 as a, 2 as b, 3 as c, 4 as d"
             " union select 5 as a, 6 as b, 7 as c, 8 as d", 2),
            ("select 1 union select 2 union select 3"
             " union select 4 union select 5 union select 6", 6))
        for sql, expected in cases:
            count = self.c.query(sql).ntuples()
            self.assertIsInstance(count, int)
            self.assertEqual(count, expected)

    def testQuery(self):
        query = self.c.query
        query("drop table if exists test_table")
        self.addCleanup(query, "drop table test_table")
        res = query("create table test_table (n integer) with oids")
        self.assertIsNone(res)
        # A single-row insert returns the oid of the new row.
        res = query("insert into test_table values (1)")
        self.assertIsInstance(res, int)
        oid = query("insert into test_table select 2")
        self.assertIsInstance(oid, int)
        rows = query("select oid from test_table where n=2").getresult()
        self.assertEqual(len(rows), 1)
        row = rows[0]
        self.assertEqual(len(row), 1)
        value = row[0]
        self.assertIsInstance(value, int)
        self.assertEqual(value, oid)
        # Other statements return the number of affected rows as a string.
        res = query(
            "insert into test_table select 3 union select 4 union select 5")
        self.assertIsInstance(res, str)
        self.assertEqual(res, '3')
        res = query("update test_table set n=4 where n<5")
        self.assertIsInstance(res, str)
        self.assertEqual(res, '4')
        res = query("delete from test_table")
        self.assertIsInstance(res, str)
        self.assertEqual(res, '5')
601
602
class TestUnicodeQueries(unittest.TestCase):
    """Test unicode strings as queries via a basic pg connection.

    The non-ASCII tests send a query once as a unicode string and once
    encoded as bytes in the current client encoding, and check that the
    text comes back unchanged.  On Python 2 (unicode_strings is false)
    results are byte strings, so the expected value is first encoded to
    the client encoding.  Tests whose client encoding the database does
    not support are skipped.
    """

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')

    def tearDown(self):
        self.c.close()

    def testGetresultAscii(self):
        result = u'Hello, world!'
        q = u"select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultAscii(self):
        result = u'Hello, world!'
        q = u"select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('utf8')
        # pass the query as unicode
        try:
            v = self.c.query(q).getresult()[0][0]
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        # pass the query as bytes
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('utf8')
        # pass the query as unicode
        try:
            v = self.c.query(q).dictresult()[0]['greeting']
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        # pass the query as bytes
        q = q.encode('utf8')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin1(self):
        # Formerly also named testDictresultLatin1, which made the second
        # definition shadow this one so that it never ran.
        try:
            self.c.query('set client_encoding=latin1')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        # Uses characters that exist in latin9 but not in latin1
        # (œ, š, ž, €).
        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        # Uses characters that exist in latin9 but not in latin1
        # (œ, š, ž, €).
        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
        q = u"select '%s' as menu" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
762
763
764class TestParamQueries(unittest.TestCase):
765    """Test queries with parameters via a basic pg connection."""
766
767    def setUp(self):
768        self.c = connect()
769        self.c.query('set client_encoding=utf8')
770
771    def tearDown(self):
772        self.c.close()
773
774    def testQueryWithNoneParam(self):
775        self.assertRaises(TypeError, self.c.query, "select $1", None)
776        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
777        self.assertEqual(self.c.query("select $1::integer", (None,)
778            ).getresult(), [(None,)])
779        self.assertEqual(self.c.query("select $1::text", [None]
780            ).getresult(), [(None,)])
781        self.assertEqual(self.c.query("select $1::text", [[None]]
782            ).getresult(), [(None,)])
783
784    def testQueryWithBoolParams(self, bool_enabled=None):
785        query = self.c.query
786        if bool_enabled is not None:
787            bool_enabled_default = pg.get_bool()
788            pg.set_bool(bool_enabled)
789        try:
790            bool_on = bool_enabled or bool_enabled is None
791            v_false, v_true = (False, True) if bool_on else 'ft'
792            r_false, r_true = [(v_false,)], [(v_true,)]
793            self.assertEqual(query("select false").getresult(), r_false)
794            self.assertEqual(query("select true").getresult(), r_true)
795            q = "select $1::bool"
796            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
797            self.assertEqual(query(q, ('f',)).getresult(), r_false)
798            self.assertEqual(query(q, ('t',)).getresult(), r_true)
799            self.assertEqual(query(q, ('false',)).getresult(), r_false)
800            self.assertEqual(query(q, ('true',)).getresult(), r_true)
801            self.assertEqual(query(q, ('n',)).getresult(), r_false)
802            self.assertEqual(query(q, ('y',)).getresult(), r_true)
803            self.assertEqual(query(q, (0,)).getresult(), r_false)
804            self.assertEqual(query(q, (1,)).getresult(), r_true)
805            self.assertEqual(query(q, (False,)).getresult(), r_false)
806            self.assertEqual(query(q, (True,)).getresult(), r_true)
807        finally:
808            if bool_enabled is not None:
809                pg.set_bool(bool_enabled_default)
810
811    def testQueryWithBoolParamsNotDefault(self):
812        self.testQueryWithBoolParams(bool_enabled=not pg.get_bool())
813
814    def testQueryWithIntParams(self):
815        query = self.c.query
816        self.assertEqual(query("select 1+1").getresult(), [(2,)])
817        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
818        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
819        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
820        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
821        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
822            [(Decimal('2'),)])
823        self.assertEqual(query("select 1, $1::integer", (2,)
824            ).getresult(), [(1, 2)])
825        self.assertEqual(query("select 1 union select $1::integer", (2,)
826            ).getresult(), [(1,), (2,)])
827        self.assertEqual(query("select $1::integer+$2", (1, 2)
828            ).getresult(), [(3,)])
829        self.assertEqual(query("select $1::integer+$2", [1, 2]
830            ).getresult(), [(3,)])
831        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
832            ).getresult(), [(15,)])
833
834    def testQueryWithStrParams(self):
835        query = self.c.query
836        self.assertEqual(query("select $1||', world!'", ('Hello',)
837            ).getresult(), [('Hello, world!',)])
838        self.assertEqual(query("select $1||', world!'", ['Hello']
839            ).getresult(), [('Hello, world!',)])
840        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
841            ).getresult(), [('Hello, world!',)])
842        self.assertEqual(query("select $1::text", ('Hello, world!',)
843            ).getresult(), [('Hello, world!',)])
844        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
845            ).getresult(), [('Hello', 'world')])
846        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
847            ).getresult(), [('Hello', 'world')])
848        self.assertEqual(query("select $1::text union select $2::text",
849            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
850        try:
851            query("select 'wörld'")
852        except (pg.DataError, pg.NotSupportedError):
853            self.skipTest('database does not support utf8')
854        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
855            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
856
857    def testQueryWithUnicodeParams(self):
858        query = self.c.query
859        try:
860            query('set client_encoding=utf8')
861            query("select 'wörld'").getresult()[0][0] == 'wörld'
862        except (pg.DataError, pg.NotSupportedError):
863            self.skipTest("database does not support utf8")
864        self.assertEqual(query("select $1||', '||$2||'!'",
865            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
866
867    def testQueryWithUnicodeParamsLatin1(self):
868        query = self.c.query
869        try:
870            query('set client_encoding=latin1')
871            query("select 'wörld'").getresult()[0][0] == 'wörld'
872        except (pg.DataError, pg.NotSupportedError):
873            self.skipTest("database does not support latin1")
874        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
875        if unicode_strings:
876            self.assertEqual(r, [('Hello, wörld!',)])
877        else:
878            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
879        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
880            ('Hello', u'ЌОр'))
881        query('set client_encoding=iso_8859_1')
882        r = query("select $1||', '||$2||'!'",
883            ('Hello', u'wörld')).getresult()
884        if unicode_strings:
885            self.assertEqual(r, [('Hello, wörld!',)])
886        else:
887            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
888        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
889            ('Hello', u'ЌОр'))
890        query('set client_encoding=sql_ascii')
891        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
892            ('Hello', u'wörld'))
893
894    def testQueryWithUnicodeParamsCyrillic(self):
895        query = self.c.query
896        try:
897            query('set client_encoding=iso_8859_5')
898            query("select 'ЌОр'").getresult()[0][0] == 'ЌОр'
899        except (pg.DataError, pg.NotSupportedError):
900            self.skipTest("database does not support cyrillic")
901        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
902            ('Hello', u'wörld'))
903        r = query("select $1||', '||$2||'!'",
904            ('Hello', u'ЌОр')).getresult()
905        if unicode_strings:
906            self.assertEqual(r, [('Hello, ЌОр!',)])
907        else:
908            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
909        query('set client_encoding=sql_ascii')
910        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
911            ('Hello', u'ЌОр!'))
912
913    def testQueryWithMixedParams(self):
914        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
915            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
916        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
917            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
918
919    def testQueryWithDuplicateParams(self):
920        self.assertRaises(pg.ProgrammingError,
921            self.c.query, "select $1+$1", (1,))
922        self.assertRaises(pg.ProgrammingError,
923            self.c.query, "select $1+$1", (1, 2))
924
925    def testQueryWithZeroParams(self):
926        self.assertEqual(self.c.query("select 1+1", []
927            ).getresult(), [(2,)])
928
929    def testQueryWithGarbage(self):
930        garbage = r"'\{}+()-#[]oo324"
931        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
932            ).dictresult(), [{'garbage': garbage}])
933
934
class TestQueryResultTypes(unittest.TestCase):
    """Test proper result types via a basic pg connection."""

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set timezone='UTC'")
        # money output depends on lc_monetary; use the C locale
        # like the other test classes that check money values
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        self.c.close()

    def assert_proper_cast(self, value, pgtype, pytype):
        """Check the Python type of a cast value and of a 1-element array.

        The value is passed as a query parameter and cast to pgtype; the
        result must be an instance of pytype.  Then a one-element array of
        pgtype built from the value must come back as a list holding one
        pytype instance, except for date/time/interval arrays, which are
        returned unchanged as their literal.

        Skips the test if the json types are not supported; any other
        ProgrammingError is propagated.
        """
        q = 'select $1::%s' % (pgtype,)
        try:
            r = self.c.query(q, (value,)).getresult()[0][0]
        except pg.ProgrammingError:
            if pgtype in ('json', 'jsonb'):
                self.skipTest('database does not support json')
            raise  # do not go on with an unbound result
        self.assertIsInstance(r, pytype)
        if isinstance(value, str):
            # quote elements that would break the array literal otherwise
            if not value or ' ' in value or '{' in value:
                value = '"%s"' % value
        value = '{%s}' % value
        r = self.c.query(q + '[]', (value,)).getresult()[0][0]
        if pgtype.startswith(('date', 'time', 'interval')):
            # arrays of these are casted by the DB wrapper only
            self.assertEqual(r, value)
        else:
            self.assertIsInstance(r, list)
            self.assertEqual(len(r), 1)
            self.assertIsInstance(r[0], pytype)

    def testInt(self):
        self.assert_proper_cast(0, 'int', int)
        self.assert_proper_cast(0, 'smallint', int)
        self.assert_proper_cast(0, 'oid', int)
        self.assert_proper_cast(0, 'cid', int)
        self.assert_proper_cast(0, 'xid', int)

    def testLong(self):
        self.assert_proper_cast(0, 'bigint', long)

    def testFloat(self):
        self.assert_proper_cast(0, 'float', float)
        self.assert_proper_cast(0, 'real', float)
        # "double" alone is not a PostgreSQL type name, use float8
        self.assert_proper_cast(0, 'float8', float)
        self.assert_proper_cast(0, 'double precision', float)
        self.assert_proper_cast('infinity', 'float', float)

    def testNumeric(self):
        # was a second "testFloat" that silently shadowed the one above
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal(0), 'numeric', decimal)
        self.assert_proper_cast(decimal(0), 'decimal', decimal)

    def testMoney(self):
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal('0'), 'money', decimal)

    def testBool(self):
        bool_type = bool if pg.get_bool() else str
        self.assert_proper_cast('f', 'bool', bool_type)

    def testDate(self):
        self.assert_proper_cast('1956-01-31', 'date', str)
        self.assert_proper_cast('10:20:30', 'interval', str)
        self.assert_proper_cast('08:42:15', 'time', str)
        self.assert_proper_cast('08:42:15+00', 'timetz', str)
        self.assert_proper_cast('1956-01-31 08:42:15', 'timestamp', str)
        self.assert_proper_cast('1956-01-31 08:42:15+00', 'timestamptz', str)

    def testText(self):
        self.assert_proper_cast('', 'text', str)
        self.assert_proper_cast('', 'char', str)
        self.assert_proper_cast('', 'bpchar', str)
        self.assert_proper_cast('', 'varchar', str)

    def testBytea(self):
        self.assert_proper_cast('', 'bytea', bytes)

    def testJson(self):
        self.assert_proper_cast('{}', 'json', dict)
1017
1018
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Create the test table and detect whether lengths honor encoding."""
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test ("
                "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
                "d numeric, f4 real, f8 double precision, m money,"
                "c char(1), v4 varchar(4), c4 char(4), t text)")
            # Check whether the test database uses SQL_ASCII - this means
            # that it does not consider encoding when calculating lengths.
            c.query("set client_encoding=utf8")
            try:
                c.query("select 'À'")
            except (pg.DataError, pg.NotSupportedError):
                cls.has_encoding = False
            else:
                cls.has_encoding = c.query(
                    "select length('À') - length('a')").getresult()[0][0] == 0
        finally:
            c.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        try:
            self.c.query("truncate table test")
        finally:
            self.c.close()

    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def db_len(cls, s, encoding):
        """Return the length of s the way the database counts it.

        If the database honors the encoding, characters are counted
        (byte strings are decoded first); otherwise bytes are counted
        (unicode strings are encoded first).
        """
        if cls.has_encoding:
            s = s if isinstance(s, unicode) else s.decode(encoding)
        else:
            s = s.encode(encoding) if isinstance(s, unicode) else s
        return len(s)

    def get_back(self, encoding='utf-8'):
        """Read the test table back, checking and normalizing the values.

        Rows are ordered by the first column.  The Python type of every
        non-null column is checked; numeric values are converted to float,
        money values to the string of their float value, and char(4)
        values are right-stripped, so that rows compare equal to the
        inserted data.
        """
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], bool)
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision
                self.assertIsInstance(row[8], float)
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(self.db_len(row[10], encoding), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(self.db_len(row[11], encoding), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(self.db_len(row[12], encoding), 4)
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            row = tuple(row)
            data.append(row)
        return data

    def testInserttable1Row(self):
        data = self.data[2:3]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        data = self.data
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMultipleRows(self):
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        data = [(None,) * 14] * 100
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableByteValues(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"kÀse сыр pont-l'évêque")
        row_bytes = tuple(s.encode('utf-8')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_bytes] * 2
        self.c.inserttable('test', data)
        if unicode_strings:
            data = [row_unicode] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeUtf8(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"kÀse сыр pont-l'évêque")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            row_bytes = tuple(s.encode('utf-8')
                if isinstance(s, unicode) else s for s in row_unicode)
            data = [row_bytes] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeLatin1(self):
        try:
            self.c.query("set client_encoding=latin1")
            self.c.query("select '\xa5'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode € sign with latin1 encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
        # replace it with the yen sign, a single latin1 character,
        # so that the value still fits into the char(1) column
        row_unicode = tuple(s.replace(u'€', u'\xa5')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            row_bytes = tuple(s.encode('latin1')
                if isinstance(s, unicode) else s for s in row_unicode)
            data = [row_bytes] * 2
        self.assertEqual(self.get_back('latin1'), data)

    def testInserttableUnicodeLatin9(self):
        try:
            self.c.query("set client_encoding=latin9")
            self.c.query("select '€'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            row_bytes = tuple(s.encode('latin9')
                if isinstance(s, unicode) else s for s in row_unicode)
            data = [row_bytes] * 2
        self.assertEqual(self.get_back('latin9'), data)

    def testInserttableNoEncoding(self):
        self.c.query("set client_encoding=sql_ascii")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode non-ascii unicode without a specific encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
1256
1257
class TestDirectSocketAccess(unittest.TestCase):
    """Test the copy command through putline, getline and endcopy."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        conn = connect()
        for sql in ("drop table if exists test cascade",
                    "create table test (i int, v varchar(16))"):
            conn.query(sql)
        conn.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        conn = connect()
        conn.query("drop table test cascade")
        conn.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    def testPutline(self):
        send = self.c.putline
        rows = list(enumerate("apple pear plum cherry banana".split()))
        self.c.query("copy test from stdin")
        try:
            for num, name in rows:
                send("%d\t%s\n" % (num, name))
            send("\\.\n")
        finally:
            self.c.endcopy()
        stored = self.c.query("select * from test").getresult()
        self.assertEqual(stored, rows)

    def testPutlineBytesAndUnicode(self):
        send = self.c.putline
        run = self.c.query
        try:
            run("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        run("copy test from stdin")
        try:
            send(u"47\tkÀse\n".encode('utf8'))
            send("35\twÃŒrstel\n")
            send(b"\\.\n")
        finally:
            self.c.endcopy()
        stored = run("select * from test").getresult()
        self.assertEqual(stored, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        receive = self.c.getline
        rows = list(enumerate("apple banana pear plum strawberry".split()))
        count = len(rows)
        self.c.inserttable('test', rows)
        self.c.query("copy test to stdout")
        try:
            # after the data rows comes the end marker, then nothing
            for index in range(count + 2):
                line = receive()
                if index == count:
                    self.assertEqual(line, '\\.')
                elif index > count:
                    self.assertIsNone(line)
                else:
                    self.assertEqual(line, '%d\t%s' % rows[index])
        finally:
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testGetlineBytesAndUnicode(self):
        receive = self.c.getline
        run = self.c.query
        try:
            run("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        rows = [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')]
        self.c.inserttable('test', rows)
        run("copy test to stdout")
        try:
            for expected in ('54\tkÀse', '73\twÃŒrstel'):
                line = receive()
                self.assertIsInstance(line, str)
                self.assertEqual(line, expected)
            self.assertEqual(receive(), '\\.')
            self.assertIsNone(receive())
        finally:
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        conn = self.c
        self.assertRaises(TypeError, conn.putline)
        self.assertRaises(TypeError, conn.getline, 'invalid')
        self.assertRaises(TypeError, conn.endcopy, 'invalid')
1368
1369
class TestNotificatons(unittest.TestCase):
    """Test notification and notice receiver support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # run registered cleanups while the connection is still open,
        # since they issue queries on it
        self.doCleanups()
        self.c.close()

    def _check_notification(self, note, payload):
        """Assert that note is a 3-tuple for test_notify with payload."""
        self.assertIsInstance(note, tuple)
        self.assertEqual(len(note), 3)
        channel, number, extra = note
        self.assertIsInstance(channel, str)
        self.assertIsInstance(number, int)
        self.assertIsInstance(extra, str)
        self.assertEqual(channel, 'test_notify')
        self.assertEqual(extra, payload)

    def testGetNotify(self):
        poll = self.c.getnotify
        run = self.c.query
        self.assertIsNone(poll())
        run('listen test_notify')
        try:
            self.assertIsNone(poll())
            run("notify test_notify")
            self._check_notification(poll(), '')
            self.assertIsNone(poll())
            run("notify test_notify, 'test_payload'")
            self._check_notification(poll(), 'test_payload')
            self.assertIsNone(poll())
        finally:
            run('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        setter = self.c.set_notice_receiver
        for invalid in (42, 'invalid'):
            self.assertRaises(TypeError, setter, invalid)
        self.assertIsNone(setter(lambda notice: None))
        self.assertIsNone(setter(None))

    def testSetAndGetNoticeReceiver(self):
        def receiver(notice):
            return None
        self.assertIsNone(self.c.set_notice_receiver(receiver))
        self.assertIs(self.c.get_notice_receiver(), receiver)
        self.assertIsNone(self.c.set_notice_receiver(None))
        self.assertIsNone(self.c.get_notice_receiver())

    def testNoticeReceiver(self):
        self.addCleanup(self.c.query, 'drop function bilbo_notice();')
        self.c.query('''create function bilbo_notice() returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        received = {}

        def notice_receiver(notice):
            # collect all public attributes of the notice object
            public = (name for name in dir(notice)
                      if not name.startswith('__'))
            for name in public:
                value = getattr(notice, name)
                if isinstance(value, str):
                    # map the German severity text to English
                    value = value.replace('WARNUNG', 'WARNING')
                received[name] = value

        self.c.set_notice_receiver(notice_receiver)
        self.c.query('select bilbo_notice()')
        expected = dict(
            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
            severity='WARNING', primary='Bilbo was here!',
            detail=None, hint=None)
        self.assertEqual(received, expected)
1440                if isinstance(value, str):
1441                    value = value.replace('WARNUNG', 'WARNING')
1442                received[attr] = value
1443
1444        self.c.set_notice_receiver(notice_receiver)
1445        self.c.query('select bilbo_notice()')
1446        self.assertEqual(received, dict(
1447            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
1448            severity='WARNING', primary='Bilbo was here!',
1449            detail=None, hint=None))
1450
1451
1452class TestConfigFunctions(unittest.TestCase):
1453    """Test the functions for changing default settings.
1454
1455    To test the effect of most of these functions, we need a database
1456    connection.  That's why they are covered in this test module.
1457    """
1458
1459    def setUp(self):
1460        self.c = connect()
1461        self.c.query("set client_encoding=utf8")
1462        self.c.query('set bytea_output=hex')
1463        self.c.query("set lc_monetary='C'")
1464
1465    def tearDown(self):
1466        self.c.close()
1467
1468    def testGetDecimalPoint(self):
1469        point = pg.get_decimal_point()
1470        # error if a parameter is passed
1471        self.assertRaises(TypeError, pg.get_decimal_point, point)
1472        self.assertIsInstance(point, str)
1473        self.assertEqual(point, '.')  # the default setting
1474        pg.set_decimal_point(',')
1475        try:
1476            r = pg.get_decimal_point()
1477        finally:
1478            pg.set_decimal_point(point)
1479        self.assertIsInstance(r, str)
1480        self.assertEqual(r, ',')
1481        pg.set_decimal_point("'")
1482        try:
1483            r = pg.get_decimal_point()
1484        finally:
1485            pg.set_decimal_point(point)
1486        self.assertIsInstance(r, str)
1487        self.assertEqual(r, "'")
1488        pg.set_decimal_point('')
1489        try:
1490            r = pg.get_decimal_point()
1491        finally:
1492            pg.set_decimal_point(point)
1493        self.assertIsNone(r)
1494        pg.set_decimal_point(None)
1495        try:
1496            r = pg.get_decimal_point()
1497        finally:
1498            pg.set_decimal_point(point)
1499        self.assertIsNone(r)
1500
1501    def testSetDecimalPoint(self):
1502        d = pg.Decimal
1503        point = pg.get_decimal_point()
1504        self.assertRaises(TypeError, pg.set_decimal_point)
1505        # error if decimal point is not a string
1506        self.assertRaises(TypeError, pg.set_decimal_point, 0)
1507        # error if more than one decimal point passed
1508        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
1509        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
1510        # error if decimal point is not a punctuation character
1511        self.assertRaises(TypeError, pg.set_decimal_point, '0')
1512        query = self.c.query
1513        # check that money values are interpreted as decimal values
1514        # only if decimal_point is set, and that the result is correct
1515        # only if it is set suitable for the current lc_monetary setting
1516        select_money = "select '34.25'::money"
1517        proper_money = d('34.25')
1518        bad_money = d('3425')
1519        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
1520        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
1521        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
1522        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
1523            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
1524        # first try with English localization (using the point)
1525        for lc in en_locales:
1526            try:
1527                query("set lc_monetary='%s'" % lc)
1528            except pg.DataError:
1529                pass
1530            else:
1531                break
1532        else:
1533            self.skipTest("cannot set English money locale")
1534        try:
1535            r = query(select_money)
1536        except pg.DataError:
1537            # this can happen if the currency signs cannot be
1538            # converted using the encoding of the test database
1539            self.skipTest("database does not support English money")
1540        pg.set_decimal_point(None)
1541        try:
1542            r = r.getresult()[0][0]
1543        finally:
1544            pg.set_decimal_point(point)
1545        self.assertIsInstance(r, str)
1546        self.assertIn(r, en_money)
1547        r = query(select_money)
1548        pg.set_decimal_point('')
1549        try:
1550            r = r.getresult()[0][0]
1551        finally:
1552            pg.set_decimal_point(point)
1553        self.assertIsInstance(r, str)
1554        self.assertIn(r, en_money)
1555        r = query(select_money)
1556        pg.set_decimal_point('.')
1557        try:
1558            r = r.getresult()[0][0]
1559        finally:
1560            pg.set_decimal_point(point)
1561        self.assertIsInstance(r, d)
1562        self.assertEqual(r, proper_money)
1563        r = query(select_money)
1564        pg.set_decimal_point(',')
1565        try:
1566            r = r.getresult()[0][0]
1567        finally:
1568            pg.set_decimal_point(point)
1569        self.assertIsInstance(r, d)
1570        self.assertEqual(r, bad_money)
1571        r = query(select_money)
1572        pg.set_decimal_point("'")
1573        try:
1574            r = r.getresult()[0][0]
1575        finally:
1576            pg.set_decimal_point(point)
1577        self.assertIsInstance(r, d)
1578        self.assertEqual(r, bad_money)
1579        # then try with German localization (using the comma)
1580        for lc in de_locales:
1581            try:
1582                query("set lc_monetary='%s'" % lc)
1583            except pg.DataError:
1584                pass
1585            else:
1586                break
1587        else:
1588            self.skipTest("cannot set German money locale")
1589        select_money = select_money.replace('.', ',')
1590        try:
1591            r = query(select_money)
1592        except pg.DataError:
1593            self.skipTest("database does not support English money")
1594        pg.set_decimal_point(None)
1595        try:
1596            r = r.getresult()[0][0]
1597        finally:
1598            pg.set_decimal_point(point)
1599        self.assertIsInstance(r, str)
1600        self.assertIn(r, de_money)
1601        r = query(select_money)
1602        pg.set_decimal_point('')
1603        try:
1604            r = r.getresult()[0][0]
1605        finally:
1606            pg.set_decimal_point(point)
1607        self.assertIsInstance(r, str)
1608        self.assertIn(r, de_money)
1609        r = query(select_money)
1610        pg.set_decimal_point(',')
1611        try:
1612            r = r.getresult()[0][0]
1613        finally:
1614            pg.set_decimal_point(point)
1615        self.assertIsInstance(r, d)
1616        self.assertEqual(r, proper_money)
1617        r = query(select_money)
1618        pg.set_decimal_point('.')
1619        try:
1620            r = r.getresult()[0][0]
1621        finally:
1622            pg.set_decimal_point(point)
1623        self.assertEqual(r, bad_money)
1624        r = query(select_money)
1625        pg.set_decimal_point("'")
1626        try:
1627            r = r.getresult()[0][0]
1628        finally:
1629            pg.set_decimal_point(point)
1630        self.assertEqual(r, bad_money)
1631
1632    def testGetDecimal(self):
1633        decimal_class = pg.get_decimal()
1634        # error if a parameter is passed
1635        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
1636        self.assertIs(decimal_class, pg.Decimal)  # the default setting
1637        pg.set_decimal(int)
1638        try:
1639            r = pg.get_decimal()
1640        finally:
1641            pg.set_decimal(decimal_class)
1642        self.assertIs(r, int)
1643        r = pg.get_decimal()
1644        self.assertIs(r, decimal_class)
1645
1646    def testSetDecimal(self):
1647        decimal_class = pg.get_decimal()
1648        # error if no parameter is passed
1649        self.assertRaises(TypeError, pg.set_decimal)
1650        query = self.c.query
1651        try:
1652            r = query("select 3425::numeric")
1653        except pg.DatabaseError:
1654            self.skipTest('database does not support numeric')
1655        r = r.getresult()[0][0]
1656        self.assertIsInstance(r, decimal_class)
1657        self.assertEqual(r, decimal_class('3425'))
1658        r = query("select 3425::numeric")
1659        pg.set_decimal(int)
1660        try:
1661            r = r.getresult()[0][0]
1662        finally:
1663            pg.set_decimal(decimal_class)
1664        self.assertNotIsInstance(r, decimal_class)
1665        self.assertIsInstance(r, int)
1666        self.assertEqual(r, int(3425))
1667
1668    def testGetBool(self):
1669        use_bool = pg.get_bool()
1670        # error if a parameter is passed
1671        self.assertRaises(TypeError, pg.get_bool, use_bool)
1672        self.assertIsInstance(use_bool, bool)
1673        self.assertIs(use_bool, True)  # the default setting
1674        pg.set_bool(False)
1675        try:
1676            r = pg.get_bool()
1677        finally:
1678            pg.set_bool(use_bool)
1679        self.assertIsInstance(r, bool)
1680        self.assertIs(r, False)
1681        pg.set_bool(True)
1682        try:
1683            r = pg.get_bool()
1684        finally:
1685            pg.set_bool(use_bool)
1686        self.assertIsInstance(r, bool)
1687        self.assertIs(r, True)
1688        pg.set_bool(0)
1689        try:
1690            r = pg.get_bool()
1691        finally:
1692            pg.set_bool(use_bool)
1693        self.assertIsInstance(r, bool)
1694        self.assertIs(r, False)
1695        pg.set_bool(1)
1696        try:
1697            r = pg.get_bool()
1698        finally:
1699            pg.set_bool(use_bool)
1700        self.assertIsInstance(r, bool)
1701        self.assertIs(r, True)
1702
1703    def testSetBool(self):
1704        use_bool = pg.get_bool()
1705        # error if no parameter is passed
1706        self.assertRaises(TypeError, pg.set_bool)
1707        query = self.c.query
1708        try:
1709            r = query("select true::bool")
1710        except pg.ProgrammingError:
1711            self.skipTest('database does not support bool')
1712        r = r.getresult()[0][0]
1713        self.assertIsInstance(r, bool)
1714        self.assertEqual(r, True)
1715        r = query("select true::bool")
1716        pg.set_bool(False)
1717        try:
1718            r = r.getresult()[0][0]
1719        finally:
1720            pg.set_bool(use_bool)
1721        self.assertIsInstance(r, str)
1722        self.assertIs(r, 't')
1723        r = query("select true::bool")
1724        pg.set_bool(True)
1725        try:
1726            r = r.getresult()[0][0]
1727        finally:
1728            pg.set_bool(use_bool)
1729        self.assertIsInstance(r, bool)
1730        self.assertIs(r, True)
1731
1732    def testGetByteEscaped(self):
1733        bytea_escaped = pg.get_bytea_escaped()
1734        # error if a parameter is passed
1735        self.assertRaises(TypeError, pg.get_bytea_escaped, bytea_escaped)
1736        self.assertIsInstance(bytea_escaped, bool)
1737        self.assertIs(bytea_escaped, False)  # the default setting
1738        pg.set_bytea_escaped(True)
1739        try:
1740            r = pg.get_bytea_escaped()
1741        finally:
1742            pg.set_bytea_escaped(bytea_escaped)
1743        self.assertIsInstance(r, bool)
1744        self.assertIs(r, True)
1745        pg.set_bytea_escaped(False)
1746        try:
1747            r = pg.get_bytea_escaped()
1748        finally:
1749            pg.set_bytea_escaped(bytea_escaped)
1750        self.assertIsInstance(r, bool)
1751        self.assertIs(r, False)
1752        pg.set_bytea_escaped(1)
1753        try:
1754            r = pg.get_bytea_escaped()
1755        finally:
1756            pg.set_bytea_escaped(bytea_escaped)
1757        self.assertIsInstance(r, bool)
1758        self.assertIs(r, True)
1759        pg.set_bytea_escaped(0)
1760        try:
1761            r = pg.get_bytea_escaped()
1762        finally:
1763            pg.set_bytea_escaped(bytea_escaped)
1764        self.assertIsInstance(r, bool)
1765        self.assertIs(r, False)
1766
1767    def testSetByteaEscaped(self):
1768        bytea_escaped = pg.get_bytea_escaped()
1769        # error if no parameter is passed
1770        self.assertRaises(TypeError, pg.set_bytea_escaped)
1771        query = self.c.query
1772        try:
1773            r = query("select 'data'::bytea")
1774        except pg.ProgrammingError:
1775            self.skipTest('database does not support bytea')
1776        r = r.getresult()[0][0]
1777        self.assertIsInstance(r, bytes)
1778        self.assertEqual(r, b'data')
1779        r = query("select 'data'::bytea")
1780        pg.set_bytea_escaped(True)
1781        try:
1782            r = r.getresult()[0][0]
1783        finally:
1784            pg.set_bytea_escaped(bytea_escaped)
1785        self.assertIsInstance(r, str)
1786        self.assertEqual(r, '\\x64617461')
1787        r = query("select 'data'::bytea")
1788        pg.set_bytea_escaped(False)
1789        try:
1790            r = r.getresult()[0][0]
1791        finally:
1792            pg.set_bytea_escaped(bytea_escaped)
1793        self.assertIsInstance(r, bytes)
1794        self.assertEqual(r, b'data')
1795
1796    def testGetNamedresult(self):
1797        namedresult = pg.get_namedresult()
1798        # error if a parameter is passed
1799        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
1800        self.assertIs(namedresult, pg._namedresult)  # the default setting
1801
1802    def testSetNamedresult(self):
1803        namedresult = pg.get_namedresult()
1804        self.assertTrue(callable(namedresult))
1805
1806        query = self.c.query
1807
1808        r = query("select 1 as x, 2 as y").namedresult()[0]
1809        self.assertIsInstance(r, tuple)
1810        self.assertEqual(r, (1, 2))
1811        self.assertIsNot(type(r), tuple)
1812        self.assertEqual(r._fields, ('x', 'y'))
1813        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
1814        self.assertEqual(r.__class__.__name__, 'Row')
1815
1816        def listresult(q):
1817            return [list(row) for row in q.getresult()]
1818
1819        pg.set_namedresult(listresult)
1820        try:
1821            r = pg.get_namedresult()
1822            self.assertIs(r, listresult)
1823            r = query("select 1 as x, 2 as y").namedresult()[0]
1824            self.assertIsInstance(r, list)
1825            self.assertEqual(r, [1, 2])
1826            self.assertIsNot(type(r), tuple)
1827            self.assertFalse(hasattr(r, '_fields'))
1828            self.assertNotEqual(r.__class__.__name__, 'Row')
1829        finally:
1830            pg.set_namedresult(namedresult)
1831
1832        r = pg.get_namedresult()
1833        self.assertIs(r, namedresult)
1834
1835    def testSetRowFactorySize(self):
1836        try:
1837            from functools import lru_cache
1838        except ImportError:  # Python < 3.2
1839            lru_cache = None
1840        queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc']
1841        query = self.c.query
1842        for maxsize in (None, 0, 1, 2, 3, 10, 1024):
1843            pg.set_row_factory_size(maxsize)
1844            for i in range(3):
1845                for q in queries:
1846                    r = query(q).namedresult()[0]
1847                    if q.endswith('abc'):
1848                        self.assertEqual(r, (123,))
1849                        self.assertEqual(r._fields, ('abc',))
1850                    else:
1851                        self.assertEqual(r, (1, 2, 3))
1852                        self.assertEqual(r._fields, ('a', 'b', 'c'))
1853            if lru_cache:
1854                info = pg._row_factory.cache_info()
1855                self.assertEqual(info.maxsize, maxsize)
1856                self.assertEqual(info.hits + info.misses, 6)
1857                self.assertEqual(info.hits,
1858                    0 if maxsize is not None and maxsize < 2 else 4)
1859
1860
class TestStandaloneEscapeFunctions(unittest.TestCase):
    """Test pg escape functions.

    The libpq interface memorizes some parameters of the last opened
    connection that influence the result of these functions.  Therefore
    we need to open a connection with fixed parameters prior to testing
    in order to ensure that the tests always run under the same conditions.
    That's why these tests are included in this test module.
    """

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Open a connection with fixed parameters, then close it again."""
        db = connect()
        try:
            query = db.query
            query('set client_encoding=sql_ascii')
            query('set standard_conforming_strings=off')
            try:
                query('set bytea_output=escape')
            except pg.ProgrammingError:
                # bytea_output does not exist before PostgreSQL 9.0, so the
                # error is only tolerated for older servers
                if db.server_version >= 90000:
                    raise
        finally:
            # close the connection even if one of the settings failed
            db.close()
        cls.cls_set_up = True

    def testEscapeString(self):
        """Check that escape_string doubles quotes and backslashes."""
        self.assertTrue(self.cls_set_up)
        f = pg.escape_string
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        r = f(u"das is' käse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, u"das is'' käse".encode('utf-8'))
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(r"It's bad to have a \ inside.")
        # the result has the same type as the input, as checked above
        self.assertIsInstance(r, str)
        self.assertEqual(r, r"It''s bad to have a \\ inside.")

    def testEscapeBytea(self):
        """Check escape_bytea with the escape output format set up above."""
        self.assertTrue(self.cls_set_up)
        f = pg.escape_bytea
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        r = f(u"das is' käse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(b'O\x00ps\xff!')
        self.assertEqual(r, b'O\\\\000ps\\\\377!')
1922
1923
if __name__ == '__main__':
    # run all test cases of this module when it is executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.