source: trunk/tests/test_classic_connection.py @ 957

Last change on this file since 957 was 957, checked in by cito, 9 months ago

Add documentation for prepared statements

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 72.6 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17import threading
18import time
19import os
20
21from collections import namedtuple
22from decimal import Decimal
23
24import pg  # the module under test
25
# A live database is required for these tests.  The defaults below are used
# unless a LOCAL_PyGreSQL.py module can be imported, in which case whatever
# it defines overrides them.  Run the tests against several PostgreSQL
# versions and against databases created with different encodings and
# locales; in particular, cover databases created with SQL_ASCII and UTF8.
dbname = 'unittest'
dbhost = None
dbport = 5432

# Try the package-relative override module first, then a top-level one.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Provide the Python 2 names "long" and "unicode" when they are missing.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# True when the native str type is a text type rather than a byte string.
unicode_strings = not (str is bytes)

windows = (os.name == 'nt')

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost(), so the host attribute
# is not queried on that platform.
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
61
62
def connect():
    """Open a pg connection to the configured test database.

    The connection is made with the module-level dbname, dbhost and dbport
    settings, and client_min_messages is set to warning for the session
    before the connection is returned.
    """
    db = pg.connect(dbname, dbhost, dbport)
    db.query("set client_min_messages=warning")
    return db
68
69
class TestCanConnect(unittest.TestCase):
    """Check that the test database can be reached at all."""

    def testCanConnect(self):
        # Opening the connection must work ...
        try:
            db = connect()
        except pg.Error as exc:
            self.fail('Cannot connect to database %s:\n%s' % (dbname, exc))
        # ... and so must closing it again.
        try:
            db.close()
        except pg.Error:
            self.fail('Cannot close the database connection')
82
83
class TestConnectObject(unittest.TestCase):
    """Test attributes and methods of the basic pg connection object.

    Every test runs on a fresh connection obtained from connect(); the
    connection is closed again when the test has finished.
    """

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        # A test may leave the connection closed; a second close() raises
        # pg.InternalError (see testMethodClose), which is harmless here.
        try:
            self.connection.close()
        except pg.InternalError:
            pass

    def is_method(self, attribute):
        """Check if given attribute on the connection is a method.

        When do_not_ask_for_host is set, 'host' is never looked up on the
        connection and is reported as not being a method.
        """
        if attribute == 'host' and do_not_ask_for_host:
            return False
        return callable(getattr(self.connection, attribute))

    def testClassName(self):
        self.assertEqual(self.connection.__class__.__name__, 'Connection')

    def testModuleName(self):
        self.assertEqual(self.connection.__class__.__module__, 'pg')

    def testStr(self):
        text = str(self.connection)
        self.assertTrue(text.startswith('<pg.Connection object'), text)

    def testRepr(self):
        text = repr(self.connection)
        self.assertTrue(text.startswith('<pg.Connection object'), text)

    def testAllConnectAttributes(self):
        # dir() yields sorted names, so the expected list is sorted as well
        expected = '''db error host options port
            protocol_version server_version status user'''.split()
        found = [name for name in dir(self.connection)
            if not (name.startswith('__') or self.is_method(name))]
        self.assertEqual(expected, found)

    def testAllConnectMethods(self):
        expected = '''cancel close date_format describe_prepared endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_cast_hook get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter
            prepare putline query query_prepared reset
            set_cast_hook set_notice_receiver source transaction'''.split()
        found = [name for name in dir(self.connection)
            if not name.startswith('__') and self.is_method(name)]
        self.assertEqual(expected, found)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        # Either no error at all, or one mentioning krb5_ is accepted
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        # A missing host or a socket directory is expected to be
        # reported as 'localhost'
        if dbhost and not dbhost.startswith('/'):
            expected = dbhost
        else:
            expected = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, expected)

    def testAttributeOptions(self):
        self.assertEqual(self.connection.options, '')

    def testAttributePort(self):
        default_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or default_port)

    def testAttributeProtocolVersion(self):
        version = self.connection.protocol_version
        self.assertIsInstance(version, int)
        self.assertTrue(2 <= version < 4)

    def testAttributeServerVersion(self):
        version = self.connection.server_version
        self.assertIsInstance(version, int)
        self.assertTrue(90000 <= version < 110000)

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        # Parameters may be passed as a tuple or as a list
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testMethodEndcopy(self):
        # An IOError is tolerated here since no copy is in progress
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # give tearDown an open connection to close
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query

        def client_encoding():
            return query('show client_encoding').getresult()[0][0].upper()

        # check that client encoding gets reset
        encoding = client_encoding()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        query("set client_encoding=%s" % changed_encoding)
        self.assertEqual(client_encoding(), changed_encoding)
        self.connection.reset()
        new_encoding = client_encoding()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            # runs in the worker thread; records the error raised
            # when the query gets cancelled
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.DatabaseError as error:
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        while 1:  # make sure the query is really running
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)

    def testMethodTransaction(self):
        transaction = self.connection.transaction
        self.assertRaises(TypeError, transaction, None)
        self.assertEqual(transaction(), pg.TRANS_IDLE)
        self.connection.query('begin')
        self.assertEqual(transaction(), pg.TRANS_INTRANS)
        self.connection.query('rollback')
        self.assertEqual(transaction(), pg.TRANS_IDLE)

    def testMethodParameter(self):
        parameter = self.connection.parameter
        query = self.connection.query
        self.assertRaises(TypeError, parameter)
        self.assertIsNone(parameter('this server setting does not exist'))
        # The value reported by parameter() must agree with what SHOW
        # returns.  Both sides are upper-cased so that a mixed-case value
        # (e.g. a vendor suffix in server_version) cannot cause a mismatch.
        for name in ('server_version', 'server_encoding', 'client_encoding'):
            shown = query('show %s' % name).getresult()[0][0]
            self.assertIsNotNone(shown)
            reported = parameter(name)
            self.assertIsNotNone(reported)
            self.assertEqual(reported.upper(), shown.upper(), name)
294
295
class TestSimpleQueries(unittest.TestCase):
    """Run simple queries on a basic pg connection and check the results.

    Each test uses its own connection, stored as self.c.
    """

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # run registered cleanups while the connection is still open
        self.doCleanups()
        self.c.close()

    def testClassName(self):
        res = self.c.query("select 1")
        self.assertEqual(res.__class__.__name__, 'Query')

    def testModuleName(self):
        res = self.c.query("select 1")
        self.assertEqual(res.__class__.__module__, 'pg')

    def testStr(self):
        sql = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        expected = (
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)')
        res = self.c.query(sql)
        self.assertEqual(str(res), expected)

    def testRepr(self):
        text = repr(self.c.query("select 1"))
        self.assertTrue(text.startswith('<pg.Query object'), text)

    def testSelect0(self):
        self.c.query("select 0")

    def testSelect0Semicolon(self):
        self.c.query("select 0;")

    def testSelectDotSemicolon(self):
        self.assertRaises(pg.DatabaseError, self.c.query, "select .;")

    def testGetresult(self):
        rows = self.c.query("select 0").getresult()
        self.assertIsInstance(rows, list)
        first = rows[0]
        self.assertIsInstance(first, tuple)
        self.assertIsInstance(first[0], int)
        self.assertEqual(rows, [(0,)])

    def testGetresultLong(self):
        expected = long(9876543210)
        self.assertIsInstance(expected, long)
        value = self.c.query("select 9876543210").getresult()[0][0]
        self.assertIsInstance(value, long)
        self.assertEqual(value, expected)

    def testGetresultDecimal(self):
        expected = Decimal(98765432109876543210)
        value = self.c.query("select 98765432109876543210").getresult()[0][0]
        self.assertIsInstance(value, Decimal)
        self.assertEqual(value, expected)

    def testGetresultString(self):
        expected = 'Hello, world!'
        sql = "select '%s'" % expected
        value = self.c.query(sql).getresult()[0][0]
        self.assertIsInstance(value, str)
        self.assertEqual(value, expected)

    def testDictresult(self):
        rows = self.c.query("select 0 as alias0").dictresult()
        self.assertIsInstance(rows, list)
        first = rows[0]
        self.assertIsInstance(first, dict)
        self.assertIsInstance(first['alias0'], int)
        self.assertEqual(rows, [{'alias0': 0}])

    def testDictresultLong(self):
        sql = "select 9876543210 as longjohnsilver"
        expected = long(9876543210)
        self.assertIsInstance(expected, long)
        value = self.c.query(sql).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(value, long)
        self.assertEqual(value, expected)

    def testDictresultDecimal(self):
        sql = "select 98765432109876543210 as longjohnsilver"
        expected = Decimal(98765432109876543210)
        value = self.c.query(sql).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(value, Decimal)
        self.assertEqual(value, expected)

    def testDictresultString(self):
        expected = 'Hello, world!'
        sql = "select '%s' as greeting" % expected
        value = self.c.query(sql).dictresult()[0]['greeting']
        self.assertIsInstance(value, str)
        self.assertEqual(value, expected)

    def testNamedresult(self):
        rows = self.c.query("select 0 as alias0").namedresult()
        self.assertEqual(rows, [(0,)])
        first = rows[0]
        self.assertEqual(first._fields, ('alias0',))
        self.assertEqual(first.alias0, 0)

    def testNamedresultWithGoodFieldnames(self):
        sql = 'select 1 as snake_case_alias, 2 as "CamelCaseAlias"'
        rows = self.c.query(sql).namedresult()
        self.assertEqual(rows, [(1, 2)])
        first = rows[0]
        self.assertEqual(first._fields, ('snake_case_alias', 'CamelCaseAlias'))

    def testNamedresultWithBadFieldnames(self):
        # Find out which replacement names namedtuple itself would pick
        # for six invalid field names
        try:
            bad = namedtuple('Bad', ['?'] * 6, rename=True)
        except TypeError:  # Python 2.6 or 3.0
            renamed = tuple('column_%d' % n for n in range(6))
        else:
            renamed = bad._fields
        sql = ('select 3 as "0alias", 4 as _alias, 5 as "alias$", 6 as "alias?",'
            ' 7 as "kebap-case-alias", 8 as break, 9 as and_a_good_one')
        rows = self.c.query(sql).namedresult()
        self.assertEqual(rows, [tuple(range(3, 10))])
        first = rows[0]
        self.assertEqual(first._fields[:6], renamed)
        self.assertEqual(first._fields[6], 'and_a_good_one')

    def testGet3Cols(self):
        rows = self.c.query("select 1,2,3").getresult()
        self.assertEqual(rows, [(1, 2, 3)])

    def testGet3DictCols(self):
        rows = self.c.query("select 1 as a,2 as b,3 as c").dictresult()
        self.assertEqual(rows, [{'a': 1, 'b': 2, 'c': 3}])

    def testGet3NamedCols(self):
        rows = self.c.query("select 1 as a,2 as b,3 as c").namedresult()
        self.assertEqual(rows, [(1, 2, 3)])
        first = rows[0]
        self.assertEqual(first._fields, ('a', 'b', 'c'))
        self.assertEqual(first.b, 2)

    def testGet3Rows(self):
        sql = "select 3 union select 1 union select 2 order by 1"
        rows = self.c.query(sql).getresult()
        self.assertEqual(rows, [(1,), (2,), (3,)])

    def testGet3DictRows(self):
        sql = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        rows = self.c.query(sql).dictresult()
        self.assertEqual(rows, [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}])

    def testGet3NamedRows(self):
        sql = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        rows = self.c.query(sql).namedresult()
        self.assertEqual(rows, [(1,), (2,), (3,)])
        for row in rows:
            self.assertEqual(row._fields, ('alias3',))

    def testDictresultNames(self):
        # an unquoted alias comes back in lower case ...
        sql = "select 'MixedCase' as MixedCaseAlias"
        rows = self.c.query(sql).dictresult()
        self.assertEqual(rows, [{'mixedcasealias': 'MixedCase'}])
        # ... while a quoted alias keeps its case
        sql = "select 'MixedCase' as \"MixedCaseAlias\""
        rows = self.c.query(sql).dictresult()
        self.assertEqual(rows, [{'MixedCaseAlias': 'MixedCase'}])

    def testNamedresultNames(self):
        expected = [('MixedCase',)]
        # an unquoted alias comes back in lower case ...
        sql = "select 'MixedCase' as MixedCaseAlias"
        rows = self.c.query(sql).namedresult()
        self.assertEqual(rows, expected)
        first = rows[0]
        self.assertEqual(first._fields, ('mixedcasealias',))
        self.assertEqual(first.mixedcasealias, 'MixedCase')
        # ... while a quoted alias keeps its case
        sql = "select 'MixedCase' as \"MixedCaseAlias\""
        rows = self.c.query(sql).namedresult()
        self.assertEqual(rows, expected)
        first = rows[0]
        self.assertEqual(first._fields, ('MixedCaseAlias',))
        self.assertEqual(first.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = 100
        num_rows = 100
        # one select listing the column numbers, repeated for every row
        one_row = "select " + ','.join([str(n) for n in range(num_cols)])
        sql = ' union all '.join([one_row] * num_rows)
        rows = self.c.query(sql).getresult()
        expected = [tuple(range(num_cols))] * num_rows
        self.assertEqual(rows, expected)

    def testListfields(self):
        sql = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        expected = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        fields = self.c.query(sql).listfields()
        self.assertIsInstance(fields, tuple)
        self.assertEqual(fields, expected)

    def testFieldname(self):
        sql = "select 0 as z, 0 as a, 0 as x, 0 as y"
        self.assertEqual(self.c.query(sql).fieldname(2), 'x')
        self.assertEqual(self.c.query(sql).fieldname(3), 'y')

    def testFieldnum(self):
        sql = "select 1 as x"
        self.assertRaises(ValueError, self.c.query(sql).fieldnum, 'y')
        sql = "select 1 as x"
        num = self.c.query(sql).fieldnum('x')
        self.assertIsInstance(num, int)
        self.assertEqual(num, 0)
        sql = "select 0 as z, 0 as a, 0 as x, 0 as y"
        num = self.c.query(sql).fieldnum('x')
        self.assertIsInstance(num, int)
        self.assertEqual(num, 2)
        num = self.c.query(sql).fieldnum('y')
        self.assertIsInstance(num, int)
        self.assertEqual(num, 3)

    def testNtuples(self):
        sql = "select 1 where false"
        count = self.c.query(sql).ntuples()
        self.assertIsInstance(count, int)
        self.assertEqual(count, 0)
        sql = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        count = self.c.query(sql).ntuples()
        self.assertIsInstance(count, int)
        self.assertEqual(count, 2)
        sql = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        count = self.c.query(sql).ntuples()
        self.assertIsInstance(count, int)
        self.assertEqual(count, 6)

    def testQuery(self):
        query = self.c.query
        query("drop table if exists test_table")
        self.addCleanup(query, "drop table test_table")
        # creating the table returns nothing
        created = query("create table test_table (n integer) with oids")
        self.assertIsNone(created)
        # inserting a single row returns an int
        inserted = query("insert into test_table values (1)")
        self.assertIsInstance(inserted, int)
        inserted = query("insert into test_table select 2")
        self.assertIsInstance(inserted, int)
        oid = inserted
        # that int must be the oid of the row just inserted
        rows = query("select oid from test_table where n=2").getresult()
        self.assertEqual(len(rows), 1)
        row = rows[0]
        self.assertEqual(len(row), 1)
        value = row[0]
        self.assertIsInstance(value, int)
        self.assertEqual(value, oid)
        # statements touching several rows return the row count as a string
        count = query(
            "insert into test_table select 3 union select 4 union select 5")
        self.assertIsInstance(count, str)
        self.assertEqual(count, '3')
        count = query("update test_table set n=4 where n<5")
        self.assertIsInstance(count, str)
        self.assertEqual(count, '4')
        count = query("delete from test_table")
        self.assertIsInstance(count, str)
        self.assertEqual(count, '5')
602
603
class TestUnicodeQueries(unittest.TestCase):
    """Test unicode strings as queries via a basic pg connection.

    Each test sends a query containing a text literal, once as a unicode
    string and once encoded to bytes, and checks that the value comes back
    as a native str equal to the original text (encoded with the client
    encoding when the native str type is a byte string).
    """

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')

    def tearDown(self):
        self.c.close()

    def _set_client_encoding(self, encoding, label):
        """Switch the client encoding, skipping the test if that fails."""
        try:
            self.c.query('set client_encoding=%s' % encoding)
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support %s" % label)

    def _expected(self, text, codec):
        """Return text in the form in which a result value is expected."""
        return text if unicode_strings else text.encode(codec)

    def _assert_str_equal(self, value, expected):
        """Assert that value is a native str equal to expected."""
        self.assertIsInstance(value, str)
        self.assertEqual(value, expected)

    def testGetresultAscii(self):
        text = u'Hello, world!'
        q = u"select '%s'" % text
        v = self.c.query(q).getresult()[0][0]
        self._assert_str_equal(v, text)

    def testDictresultAscii(self):
        text = u'Hello, world!'
        q = u"select '%s' as greeting" % text
        v = self.c.query(q).dictresult()[0]['greeting']
        self._assert_str_equal(v, text)

    def testGetresultUtf8(self):
        text = u'Hello, wörld & ЌОр!'
        q = u"select '%s'" % text
        expected = self._expected(text, 'utf8')
        # pass the query as unicode
        try:
            v = self.c.query(q).getresult()[0][0]
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self._assert_str_equal(v, expected)
        # pass the query as bytes
        v = self.c.query(q.encode('utf8')).getresult()[0][0]
        self._assert_str_equal(v, expected)

    def testDictresultUtf8(self):
        text = u'Hello, wörld & ЌОр!'
        q = u"select '%s' as greeting" % text
        expected = self._expected(text, 'utf8')
        # pass the query as unicode
        try:
            v = self.c.query(q).dictresult()[0]['greeting']
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self._assert_str_equal(v, expected)
        # pass the query as bytes
        v = self.c.query(q.encode('utf8')).dictresult()[0]['greeting']
        self._assert_str_equal(v, expected)

    # This test used to be defined under the name testDictresultLatin1,
    # which was immediately redefined below, so it never ran.
    def testGetresultLatin1(self):
        self._set_client_encoding('latin1', 'latin1')
        text = u'Hello, wörld!'
        q = u"select '%s'" % text
        expected = self._expected(text, 'latin1')
        v = self.c.query(q).getresult()[0][0]
        self._assert_str_equal(v, expected)
        v = self.c.query(q.encode('latin1')).getresult()[0][0]
        self._assert_str_equal(v, expected)

    def testDictresultLatin1(self):
        self._set_client_encoding('latin1', 'latin1')
        text = u'Hello, wörld!'
        q = u"select '%s' as greeting" % text
        expected = self._expected(text, 'latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self._assert_str_equal(v, expected)
        v = self.c.query(q.encode('latin1')).dictresult()[0]['greeting']
        self._assert_str_equal(v, expected)

    def testGetresultCyrillic(self):
        self._set_client_encoding('iso_8859_5', 'cyrillic')
        text = u'Hello, ЌОр!'
        q = u"select '%s'" % text
        expected = self._expected(text, 'cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self._assert_str_equal(v, expected)
        v = self.c.query(q.encode('cyrillic')).getresult()[0][0]
        self._assert_str_equal(v, expected)

    def testDictresultCyrillic(self):
        self._set_client_encoding('iso_8859_5', 'cyrillic')
        text = u'Hello, ЌОр!'
        q = u"select '%s' as greeting" % text
        expected = self._expected(text, 'cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self._assert_str_equal(v, expected)
        v = self.c.query(q.encode('cyrillic')).dictresult()[0]['greeting']
        self._assert_str_equal(v, expected)

    # NOTE(review): the text literal in the two Latin-9 tests contains
    # sequences that look like mis-decoded UTF-8 (possibly meant to read
    # "pražská šunka" and "¥") -- confirm against the repository copy.
    def testGetresultLatin9(self):
        self._set_client_encoding('latin9', 'latin9')
        text = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
        q = u"select '%s'" % text
        expected = self._expected(text, 'latin9')
        v = self.c.query(q).getresult()[0][0]
        self._assert_str_equal(v, expected)
        v = self.c.query(q.encode('latin9')).getresult()[0][0]
        self._assert_str_equal(v, expected)

    def testDictresultLatin9(self):
        self._set_client_encoding('latin9', 'latin9')
        text = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
        q = u"select '%s' as menu" % text
        expected = self._expected(text, 'latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self._assert_str_equal(v, expected)
        v = self.c.query(q.encode('latin9')).dictresult()[0]['menu']
        self._assert_str_equal(v, expected)
763
764
765class TestParamQueries(unittest.TestCase):
766    """Test queries with parameters via a basic pg connection."""
767
768    def setUp(self):
769        self.c = connect()
770        self.c.query('set client_encoding=utf8')
771
772    def tearDown(self):
773        self.c.close()
774
775    def testQueryWithNoneParam(self):
776        self.assertRaises(TypeError, self.c.query, "select $1", None)
777        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
778        self.assertEqual(self.c.query("select $1::integer", (None,)
779            ).getresult(), [(None,)])
780        self.assertEqual(self.c.query("select $1::text", [None]
781            ).getresult(), [(None,)])
782        self.assertEqual(self.c.query("select $1::text", [[None]]
783            ).getresult(), [(None,)])
784
785    def testQueryWithBoolParams(self, bool_enabled=None):
786        query = self.c.query
787        if bool_enabled is not None:
788            bool_enabled_default = pg.get_bool()
789            pg.set_bool(bool_enabled)
790        try:
791            bool_on = bool_enabled or bool_enabled is None
792            v_false, v_true = (False, True) if bool_on else 'ft'
793            r_false, r_true = [(v_false,)], [(v_true,)]
794            self.assertEqual(query("select false").getresult(), r_false)
795            self.assertEqual(query("select true").getresult(), r_true)
796            q = "select $1::bool"
797            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
798            self.assertEqual(query(q, ('f',)).getresult(), r_false)
799            self.assertEqual(query(q, ('t',)).getresult(), r_true)
800            self.assertEqual(query(q, ('false',)).getresult(), r_false)
801            self.assertEqual(query(q, ('true',)).getresult(), r_true)
802            self.assertEqual(query(q, ('n',)).getresult(), r_false)
803            self.assertEqual(query(q, ('y',)).getresult(), r_true)
804            self.assertEqual(query(q, (0,)).getresult(), r_false)
805            self.assertEqual(query(q, (1,)).getresult(), r_true)
806            self.assertEqual(query(q, (False,)).getresult(), r_false)
807            self.assertEqual(query(q, (True,)).getresult(), r_true)
808        finally:
809            if bool_enabled is not None:
810                pg.set_bool(bool_enabled_default)
811
812    def testQueryWithBoolParamsNotDefault(self):
813        self.testQueryWithBoolParams(bool_enabled=not pg.get_bool())
814
815    def testQueryWithIntParams(self):
816        query = self.c.query
817        self.assertEqual(query("select 1+1").getresult(), [(2,)])
818        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
819        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
820        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
821        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
822        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
823            [(Decimal('2'),)])
824        self.assertEqual(query("select 1, $1::integer", (2,)
825            ).getresult(), [(1, 2)])
826        self.assertEqual(query("select 1 union select $1::integer", (2,)
827            ).getresult(), [(1,), (2,)])
828        self.assertEqual(query("select $1::integer+$2", (1, 2)
829            ).getresult(), [(3,)])
830        self.assertEqual(query("select $1::integer+$2", [1, 2]
831            ).getresult(), [(3,)])
832        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
833            ).getresult(), [(15,)])
834
835    def testQueryWithStrParams(self):
836        query = self.c.query
837        self.assertEqual(query("select $1||', world!'", ('Hello',)
838            ).getresult(), [('Hello, world!',)])
839        self.assertEqual(query("select $1||', world!'", ['Hello']
840            ).getresult(), [('Hello, world!',)])
841        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
842            ).getresult(), [('Hello, world!',)])
843        self.assertEqual(query("select $1::text", ('Hello, world!',)
844            ).getresult(), [('Hello, world!',)])
845        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
846            ).getresult(), [('Hello', 'world')])
847        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
848            ).getresult(), [('Hello', 'world')])
849        self.assertEqual(query("select $1::text union select $2::text",
850            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
851        try:
852            query("select 'wörld'")
853        except (pg.DataError, pg.NotSupportedError):
854            self.skipTest('database does not support utf8')
855        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
856            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
857
858    def testQueryWithUnicodeParams(self):
859        query = self.c.query
860        try:
861            query('set client_encoding=utf8')
862            query("select 'wörld'").getresult()[0][0] == 'wörld'
863        except (pg.DataError, pg.NotSupportedError):
864            self.skipTest("database does not support utf8")
865        self.assertEqual(query("select $1||', '||$2||'!'",
866            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
867
868    def testQueryWithUnicodeParamsLatin1(self):
869        query = self.c.query
870        try:
871            query('set client_encoding=latin1')
872            query("select 'wörld'").getresult()[0][0] == 'wörld'
873        except (pg.DataError, pg.NotSupportedError):
874            self.skipTest("database does not support latin1")
875        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
876        if unicode_strings:
877            self.assertEqual(r, [('Hello, wörld!',)])
878        else:
879            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
880        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
881            ('Hello', u'ЌОр'))
882        query('set client_encoding=iso_8859_1')
883        r = query("select $1||', '||$2||'!'",
884            ('Hello', u'wörld')).getresult()
885        if unicode_strings:
886            self.assertEqual(r, [('Hello, wörld!',)])
887        else:
888            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
889        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
890            ('Hello', u'ЌОр'))
891        query('set client_encoding=sql_ascii')
892        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
893            ('Hello', u'wörld'))
894
895    def testQueryWithUnicodeParamsCyrillic(self):
896        query = self.c.query
897        try:
898            query('set client_encoding=iso_8859_5')
899            query("select 'ЌОр'").getresult()[0][0] == 'ЌОр'
900        except (pg.DataError, pg.NotSupportedError):
901            self.skipTest("database does not support cyrillic")
902        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
903            ('Hello', u'wörld'))
904        r = query("select $1||', '||$2||'!'",
905            ('Hello', u'ЌОр')).getresult()
906        if unicode_strings:
907            self.assertEqual(r, [('Hello, ЌОр!',)])
908        else:
909            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
910        query('set client_encoding=sql_ascii')
911        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
912            ('Hello', u'ЌОр!'))
913
914    def testQueryWithMixedParams(self):
915        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
916            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
917        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
918            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
919
920    def testQueryWithDuplicateParams(self):
921        self.assertRaises(pg.ProgrammingError,
922            self.c.query, "select $1+$1", (1,))
923        self.assertRaises(pg.ProgrammingError,
924            self.c.query, "select $1+$1", (1, 2))
925
926    def testQueryWithZeroParams(self):
927        self.assertEqual(self.c.query("select 1+1", []
928            ).getresult(), [(2,)])
929
930    def testQueryWithGarbage(self):
931        garbage = r"'\{}+()-#[]oo324"
932        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
933            ).dictresult(), [{'garbage': garbage}])
934
935
class TestPreparedQueries(unittest.TestCase):
    """Test prepared queries via a basic pg connection."""

    def setUp(self):
        # each test runs on its own connection with utf8 client encoding
        self.c = connection = connect()
        connection.query('set client_encoding=utf8')

    def tearDown(self):
        self.c.close()

    def testEmptyPreparedStatement(self):
        # preparing an empty statement is accepted,
        # running it is expected to raise a ValueError
        con = self.c
        con.prepare('', '')
        self.assertRaises(ValueError, con.query_prepared, '')

    def testInvalidPreparedStatement(self):
        prepare = self.c.prepare
        self.assertRaises(pg.ProgrammingError, prepare, '', 'bad')

    def testDuplicatePreparedStatement(self):
        # the same statement name cannot be prepared twice
        prepare = self.c.prepare
        self.assertIsNone(prepare('q', 'select 1'))
        self.assertRaises(pg.ProgrammingError, prepare, 'q', 'select 2')

    def testNonExistentPreparedStatement(self):
        run = self.c.query_prepared
        self.assertRaises(pg.OperationalError, run, 'does-not-exist')

    def testAnonymousQueryWithoutParams(self):
        con = self.c
        self.assertIsNone(con.prepare('', "select 'anon'"))
        rows = con.query_prepared('').getresult()
        self.assertEqual(rows, [('anon',)])

    def testNamedQueryWithoutParams(self):
        con = self.c
        self.assertIsNone(con.prepare('hello', "select 'world'"))
        rows = con.query_prepared('hello').getresult()
        self.assertEqual(rows, [('world',)])

    def testMultipleNamedQueriesWithoutParams(self):
        con = self.c
        # prepare both statements first, then run them
        numbers = (17, 42)
        for number in numbers:
            self.assertIsNone(
                con.prepare('query%d' % number, "select %d" % number))
        for number in numbers:
            rows = con.query_prepared('query%d' % number).getresult()
            self.assertEqual(rows, [(number,)])

    def testAnonymousQueryWithParams(self):
        con = self.c
        self.assertIsNone(con.prepare('', "select $1 || ', ' || $2"))
        rows = con.query_prepared('', ['hello', 'world']).getresult()
        self.assertEqual(rows, [('hello, world',)])
        # preparing the anonymous statement again replaces the first one
        self.assertIsNone(con.prepare('', "select 1+ $1 + $2 + $3"))
        rows = con.query_prepared('', [17, -5, 29]).getresult()
        self.assertEqual(rows, [(42,)])

    def testMultipleNamedQueriesWithParams(self):
        con = self.c
        self.assertIsNone(con.prepare('q1', "select $1 || '!'"))
        self.assertIsNone(con.prepare('q2', "select $1 || '-' || $2"))
        rows = con.query_prepared('q1', ['hello']).getresult()
        self.assertEqual(rows, [('hello!',)])
        rows = con.query_prepared('q2', ['he', 'lo']).getresult()
        self.assertEqual(rows, [('he-lo',)])

    def testDescribeNonExistentQuery(self):
        describe = self.c.describe_prepared
        self.assertRaises(pg.OperationalError, describe, 'does-not-exist')

    def testDescribeAnonymousQuery(self):
        con = self.c
        con.prepare('', "select 1::int, 'a'::char")
        fields = con.describe_prepared('').listfields()
        self.assertEqual(fields, ('int4', 'bpchar'))

    def testDescribeNamedQuery(self):
        con = self.c
        con.prepare('myquery', "select 1 as first, 2 as second")
        fields = con.describe_prepared('myquery').listfields()
        self.assertEqual(fields, ('first', 'second'))

    def testDescribeMultipleNamedQueries(self):
        con = self.c
        con.prepare('query1', "select 1::int")
        con.prepare('query2', "select 1::int, 2::int")
        fields = con.describe_prepared('query1').listfields()
        self.assertEqual(fields, ('int4',))
        fields = con.describe_prepared('query2').listfields()
        self.assertEqual(fields, ('int4', 'int4'))
1014
1015
1016class TestQueryResultTypes(unittest.TestCase):
1017    """Test proper result types via a basic pg connection."""
1018
1019    def setUp(self):
1020        self.c = connect()
1021        self.c.query('set client_encoding=utf8')
1022        self.c.query("set datestyle='ISO,YMD'")
1023        self.c.query("set timezone='UTC'")
1024
1025    def tearDown(self):
1026        self.c.close()
1027
1028    def assert_proper_cast(self, value, pgtype, pytype):
1029        q = 'select $1::%s' % (pgtype,)
1030        try:
1031            r = self.c.query(q, (value,)).getresult()[0][0]
1032        except pg.ProgrammingError:
1033            if pgtype in ('json', 'jsonb'):
1034                self.skipTest('database does not support json')
1035        self.assertIsInstance(r, pytype)
1036        if isinstance(value, str):
1037            if not value or ' ' in value or '{' in value:
1038                value = '"%s"' % value
1039        value = '{%s}' % value
1040        r = self.c.query(q + '[]', (value,)).getresult()[0][0]
1041        if pgtype.startswith(('date', 'time', 'interval')):
1042            # arrays of these are casted by the DB wrapper only
1043            self.assertEqual(r, value)
1044        else:
1045            self.assertIsInstance(r, list)
1046            self.assertEqual(len(r), 1)
1047            self.assertIsInstance(r[0], pytype)
1048
1049    def testInt(self):
1050        self.assert_proper_cast(0, 'int', int)
1051        self.assert_proper_cast(0, 'smallint', int)
1052        self.assert_proper_cast(0, 'oid', int)
1053        self.assert_proper_cast(0, 'cid', int)
1054        self.assert_proper_cast(0, 'xid', int)
1055
1056    def testLong(self):
1057        self.assert_proper_cast(0, 'bigint', long)
1058
1059    def testFloat(self):
1060        self.assert_proper_cast(0, 'float', float)
1061        self.assert_proper_cast(0, 'real', float)
1062        self.assert_proper_cast(0, 'double', float)
1063        self.assert_proper_cast(0, 'double precision', float)
1064        self.assert_proper_cast('infinity', 'float', float)
1065
1066    def testFloat(self):
1067        decimal = pg.get_decimal()
1068        self.assert_proper_cast(decimal(0), 'numeric', decimal)
1069        self.assert_proper_cast(decimal(0), 'decimal', decimal)
1070
1071    def testMoney(self):
1072        decimal = pg.get_decimal()
1073        self.assert_proper_cast(decimal('0'), 'money', decimal)
1074
1075    def testBool(self):
1076        bool_type = bool if pg.get_bool() else str
1077        self.assert_proper_cast('f', 'bool', bool_type)
1078
1079    def testDate(self):
1080        self.assert_proper_cast('1956-01-31', 'date', str)
1081        self.assert_proper_cast('10:20:30', 'interval', str)
1082        self.assert_proper_cast('08:42:15', 'time', str)
1083        self.assert_proper_cast('08:42:15+00', 'timetz', str)
1084        self.assert_proper_cast('1956-01-31 08:42:15', 'timestamp', str)
1085        self.assert_proper_cast('1956-01-31 08:42:15+00', 'timestamptz', str)
1086
1087    def testText(self):
1088        self.assert_proper_cast('', 'text', str)
1089        self.assert_proper_cast('', 'char', str)
1090        self.assert_proper_cast('', 'bpchar', str)
1091        self.assert_proper_cast('', 'varchar', str)
1092
1093    def testBytea(self):
1094        self.assert_proper_cast('', 'bytea', bytes)
1095
1096    def testJson(self):
1097        self.assert_proper_cast('{}', 'json', dict)
1098
1099
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        c = connect()
        c.query("drop table if exists test cascade")
        c.query("create table test ("
            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
            "d numeric, f4 real, f8 double precision, m money,"
            "c char(1), v4 varchar(4), c4 char(4), t text)")
        # Check whether the test database uses SQL_ASCII - this means
        # that it does not consider encoding when calculating lengths.
        c.query("set client_encoding=utf8")
        try:
            c.query("select 'À'")
        except (pg.DataError, pg.NotSupportedError):
            cls.has_encoding = False
        else:
            cls.has_encoding = c.query(
                "select length('À') - length('a')").getresult()[0][0] == 0
        c.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        c = connect()
        c.query("drop table test cascade")
        c.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    # sample rows, one value per column of the test table
    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def db_len(cls, s, encoding):
        """Return the length of s as the test database would count it.

        If the database considers the encoding, this is the number of
        characters, otherwise the number of bytes in the given encoding.
        """
        if cls.has_encoding:
            s = s if isinstance(s, unicode) else s.decode(encoding)
        else:
            s = s.encode(encoding) if isinstance(s, unicode) else s
        return len(s)

    def _unicode_row(self, text):
        """Build a row of unicode test values with text in the last column."""
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        return (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0', c, u'bÀd', u'bÀd', text)

    @staticmethod
    def _encode_row(row, encoding):
        """Return a copy of row with all unicode values encoded to bytes."""
        return tuple(s.encode(encoding)
            if isinstance(s, unicode) else s for s in row)

    def get_back(self, encoding='utf-8'):
        """Fetch all rows of the test table in a form comparable to the input.

        The type of every non-null column value is checked.  Numeric
        values are converted to float, money values to the string form
        of their float value, and char(4) values are right-stripped.
        The encoding is used for checking the character column lengths.
        """
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], bool)
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision
                self.assertIsInstance(row[8], float)
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(self.db_len(row[10], encoding), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(self.db_len(row[11], encoding), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(self.db_len(row[12], encoding), 4)
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            data.append(tuple(row))
        return data

    def testInserttable1Row(self):
        data = self.data[2:3]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        data = self.data
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMultipleRows(self):
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        data = [(None,) * 14] * 100
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableByteValues(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        row_bytes = self._encode_row(row_unicode, 'utf-8')
        data = [row_bytes] * 2
        self.c.inserttable('test', data)
        if unicode_strings:
            # the values come back as unicode strings
            data = [row_unicode] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeUtf8(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            # the values come back as utf-8 encoded byte strings
            data = [self._encode_row(row_unicode, 'utf-8')] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeLatin1(self):
        try:
            self.c.query("set client_encoding=latin1")
            self.c.query("select 'Â¥'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode € sign with latin1 encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
        row_unicode = tuple(s.replace(u'€', u'Â¥')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'latin1')] * 2
        self.assertEqual(self.get_back('latin1'), data)

    def testInserttableUnicodeLatin9(self):
        try:
            self.c.query("set client_encoding=latin9")
            self.c.query("select '€'")
        except (pg.DataError, pg.NotSupportedError):
            # skipTest raises, so no explicit return is needed here
            self.skipTest("database does not support latin9")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'latin9')] * 2
        self.assertEqual(self.get_back('latin9'), data)

    def testInserttableNoEncoding(self):
        self.c.query("set client_encoding=sql_ascii")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode non-ascii unicode without a specific encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
1337
1338
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        # (re)create the table shared by all tests of this class
        connection = connect()
        for statement in (
                "drop table if exists test cascade",
                "create table test (i int, v varchar(16))"):
            connection.query(statement)
        connection.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        connection = connect()
        connection.query("drop table test cascade")
        connection.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connection = connect()
        connection.query("set client_encoding=utf8")

    def tearDown(self):
        connection = self.c
        connection.query("truncate table test")
        connection.close()

    def testPutline(self):
        con = self.c
        fruits = "apple pear plum cherry banana".split()
        rows = list(enumerate(fruits))
        con.query("copy test from stdin")
        try:
            for row in rows:
                con.putline("%d\t%s\n" % row)
            # end-of-data marker of the copy command
            con.putline("\\.\n")
        finally:
            con.endcopy()
        stored = con.query("select * from test").getresult()
        self.assertEqual(stored, rows)

    def testPutlineBytesAndUnicode(self):
        con = self.c
        try:
            con.query("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        con.query("copy test from stdin")
        # lines are passed as encoded bytes as well as native strings
        lines = (u"47\tkÀse\n".encode('utf8'), "35\twÃŒrstel\n", b"\\.\n")
        try:
            for line in lines:
                con.putline(line)
        finally:
            con.endcopy()
        stored = con.query("select * from test").getresult()
        self.assertEqual(stored, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        con = self.c
        fruits = "apple banana pear plum strawberry".split()
        rows = list(enumerate(fruits))
        con.inserttable('test', rows)
        con.query("copy test to stdout")
        try:
            # first the data lines, then the end marker, then nothing
            for row in rows:
                self.assertEqual(con.getline(), '%d\t%s' % row)
            self.assertEqual(con.getline(), '\\.')
            self.assertIsNone(con.getline())
        finally:
            try:
                con.endcopy()
            except IOError:
                pass

    def testGetlineBytesAndUnicode(self):
        con = self.c
        try:
            con.query("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        rows = [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')]
        con.inserttable('test', rows)
        con.query("copy test to stdout")
        try:
            for expected in ('54\tkÀse', '73\twÃŒrstel'):
                line = con.getline()
                self.assertIsInstance(line, str)
                self.assertEqual(line, expected)
            self.assertEqual(con.getline(), '\\.')
            self.assertIsNone(con.getline())
        finally:
            try:
                con.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        con = self.c
        self.assertRaises(TypeError, con.putline)
        self.assertRaises(TypeError, con.getline, 'invalid')
        self.assertRaises(TypeError, con.endcopy, 'invalid')
1449
1450
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # run the cleanups explicitly while the connection is still open
        self.doCleanups()
        self.c.close()

    def _check_notification(self, notification, payload):
        """Check a tuple returned by getnotify() for the test_notify event.

        The tuple must consist of the event name 'test_notify', an int
        and the given payload string.
        """
        self.assertIsInstance(notification, tuple)
        self.assertEqual(len(notification), 3)
        self.assertIsInstance(notification[0], str)
        self.assertIsInstance(notification[1], int)
        self.assertIsInstance(notification[2], str)
        self.assertEqual(notification[0], 'test_notify')
        self.assertEqual(notification[2], payload)

    def testGetNotify(self):
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(self.c.getnotify())
            # notification without payload
            query("notify test_notify")
            self._check_notification(getnotify(), '')
            self.assertIsNone(self.c.getnotify())
            # notification with payload
            query("notify test_notify, 'test_payload'")
            self._check_notification(getnotify(), 'test_payload')
            self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 'invalid')
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))
        self.assertIsNone(self.c.set_notice_receiver(None))

    def testSetAndGetNoticeReceiver(self):

        def receiver(notice):
            return None

        self.assertIsNone(self.c.set_notice_receiver(receiver))
        self.assertIs(self.c.get_notice_receiver(), receiver)
        self.assertIsNone(self.c.set_notice_receiver(None))
        self.assertIsNone(self.c.get_notice_receiver())

    def testNoticeReceiver(self):
        self.addCleanup(self.c.query, 'drop function bilbo_notice();')
        self.c.query('''create function bilbo_notice() returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        received = {}

        def notice_receiver(notice):
            # collect all public attributes of the notice object
            for attr in dir(notice):
                if attr.startswith('__'):
                    continue
                value = getattr(notice, attr)
                if isinstance(value, str):
                    # the server may report the severity in German
                    value = value.replace('WARNUNG', 'WARNING')
                received[attr] = value

        self.c.set_notice_receiver(notice_receiver)
        self.c.query('select bilbo_notice()')
        self.assertEqual(received, dict(
            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
            severity='WARNING', primary='Bilbo was here!',
            detail=None, hint=None))
1531
1532
1533class TestConfigFunctions(unittest.TestCase):
1534    """Test the functions for changing default settings.
1535
1536    To test the effect of most of these functions, we need a database
1537    connection.  That's why they are covered in this test module.
1538    """
1539
1540    def setUp(self):
1541        self.c = connect()
1542        self.c.query("set client_encoding=utf8")
1543        self.c.query('set bytea_output=hex')
1544        self.c.query("set lc_monetary='C'")
1545
1546    def tearDown(self):
1547        self.c.close()
1548
1549    def testGetDecimalPoint(self):
1550        point = pg.get_decimal_point()
1551        # error if a parameter is passed
1552        self.assertRaises(TypeError, pg.get_decimal_point, point)
1553        self.assertIsInstance(point, str)
1554        self.assertEqual(point, '.')  # the default setting
1555        pg.set_decimal_point(',')
1556        try:
1557            r = pg.get_decimal_point()
1558        finally:
1559            pg.set_decimal_point(point)
1560        self.assertIsInstance(r, str)
1561        self.assertEqual(r, ',')
1562        pg.set_decimal_point("'")
1563        try:
1564            r = pg.get_decimal_point()
1565        finally:
1566            pg.set_decimal_point(point)
1567        self.assertIsInstance(r, str)
1568        self.assertEqual(r, "'")
1569        pg.set_decimal_point('')
1570        try:
1571            r = pg.get_decimal_point()
1572        finally:
1573            pg.set_decimal_point(point)
1574        self.assertIsNone(r)
1575        pg.set_decimal_point(None)
1576        try:
1577            r = pg.get_decimal_point()
1578        finally:
1579            pg.set_decimal_point(point)
1580        self.assertIsNone(r)
1581
1582    def testSetDecimalPoint(self):
1583        d = pg.Decimal
1584        point = pg.get_decimal_point()
1585        self.assertRaises(TypeError, pg.set_decimal_point)
1586        # error if decimal point is not a string
1587        self.assertRaises(TypeError, pg.set_decimal_point, 0)
1588        # error if more than one decimal point passed
1589        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
1590        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
1591        # error if decimal point is not a punctuation character
1592        self.assertRaises(TypeError, pg.set_decimal_point, '0')
1593        query = self.c.query
1594        # check that money values are interpreted as decimal values
1595        # only if decimal_point is set, and that the result is correct
1596        # only if it is set suitable for the current lc_monetary setting
1597        select_money = "select '34.25'::money"
1598        proper_money = d('34.25')
1599        bad_money = d('3425')
1600        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
1601        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
1602        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
1603        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
1604            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
1605        # first try with English localization (using the point)
1606        for lc in en_locales:
1607            try:
1608                query("set lc_monetary='%s'" % lc)
1609            except pg.DataError:
1610                pass
1611            else:
1612                break
1613        else:
1614            self.skipTest("cannot set English money locale")
1615        try:
1616            r = query(select_money)
1617        except pg.DataError:
1618            # this can happen if the currency signs cannot be
1619            # converted using the encoding of the test database
1620            self.skipTest("database does not support English money")
1621        pg.set_decimal_point(None)
1622        try:
1623            r = r.getresult()[0][0]
1624        finally:
1625            pg.set_decimal_point(point)
1626        self.assertIsInstance(r, str)
1627        self.assertIn(r, en_money)
1628        r = query(select_money)
1629        pg.set_decimal_point('')
1630        try:
1631            r = r.getresult()[0][0]
1632        finally:
1633            pg.set_decimal_point(point)
1634        self.assertIsInstance(r, str)
1635        self.assertIn(r, en_money)
1636        r = query(select_money)
1637        pg.set_decimal_point('.')
1638        try:
1639            r = r.getresult()[0][0]
1640        finally:
1641            pg.set_decimal_point(point)
1642        self.assertIsInstance(r, d)
1643        self.assertEqual(r, proper_money)
1644        r = query(select_money)
1645        pg.set_decimal_point(',')
1646        try:
1647            r = r.getresult()[0][0]
1648        finally:
1649            pg.set_decimal_point(point)
1650        self.assertIsInstance(r, d)
1651        self.assertEqual(r, bad_money)
1652        r = query(select_money)
1653        pg.set_decimal_point("'")
1654        try:
1655            r = r.getresult()[0][0]
1656        finally:
1657            pg.set_decimal_point(point)
1658        self.assertIsInstance(r, d)
1659        self.assertEqual(r, bad_money)
1660        # then try with German localization (using the comma)
1661        for lc in de_locales:
1662            try:
1663                query("set lc_monetary='%s'" % lc)
1664            except pg.DataError:
1665                pass
1666            else:
1667                break
1668        else:
1669            self.skipTest("cannot set German money locale")
1670        select_money = select_money.replace('.', ',')
1671        try:
1672            r = query(select_money)
1673        except pg.DataError:
1674            self.skipTest("database does not support English money")
1675        pg.set_decimal_point(None)
1676        try:
1677            r = r.getresult()[0][0]
1678        finally:
1679            pg.set_decimal_point(point)
1680        self.assertIsInstance(r, str)
1681        self.assertIn(r, de_money)
1682        r = query(select_money)
1683        pg.set_decimal_point('')
1684        try:
1685            r = r.getresult()[0][0]
1686        finally:
1687            pg.set_decimal_point(point)
1688        self.assertIsInstance(r, str)
1689        self.assertIn(r, de_money)
1690        r = query(select_money)
1691        pg.set_decimal_point(',')
1692        try:
1693            r = r.getresult()[0][0]
1694        finally:
1695            pg.set_decimal_point(point)
1696        self.assertIsInstance(r, d)
1697        self.assertEqual(r, proper_money)
1698        r = query(select_money)
1699        pg.set_decimal_point('.')
1700        try:
1701            r = r.getresult()[0][0]
1702        finally:
1703            pg.set_decimal_point(point)
1704        self.assertEqual(r, bad_money)
1705        r = query(select_money)
1706        pg.set_decimal_point("'")
1707        try:
1708            r = r.getresult()[0][0]
1709        finally:
1710            pg.set_decimal_point(point)
1711        self.assertEqual(r, bad_money)
1712
1713    def testGetDecimal(self):
1714        decimal_class = pg.get_decimal()
1715        # error if a parameter is passed
1716        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
1717        self.assertIs(decimal_class, pg.Decimal)  # the default setting
1718        pg.set_decimal(int)
1719        try:
1720            r = pg.get_decimal()
1721        finally:
1722            pg.set_decimal(decimal_class)
1723        self.assertIs(r, int)
1724        r = pg.get_decimal()
1725        self.assertIs(r, decimal_class)
1726
1727    def testSetDecimal(self):
1728        decimal_class = pg.get_decimal()
1729        # error if no parameter is passed
1730        self.assertRaises(TypeError, pg.set_decimal)
1731        query = self.c.query
1732        try:
1733            r = query("select 3425::numeric")
1734        except pg.DatabaseError:
1735            self.skipTest('database does not support numeric')
1736        r = r.getresult()[0][0]
1737        self.assertIsInstance(r, decimal_class)
1738        self.assertEqual(r, decimal_class('3425'))
1739        r = query("select 3425::numeric")
1740        pg.set_decimal(int)
1741        try:
1742            r = r.getresult()[0][0]
1743        finally:
1744            pg.set_decimal(decimal_class)
1745        self.assertNotIsInstance(r, decimal_class)
1746        self.assertIsInstance(r, int)
1747        self.assertEqual(r, int(3425))
1748
1749    def testGetBool(self):
1750        use_bool = pg.get_bool()
1751        # error if a parameter is passed
1752        self.assertRaises(TypeError, pg.get_bool, use_bool)
1753        self.assertIsInstance(use_bool, bool)
1754        self.assertIs(use_bool, True)  # the default setting
1755        pg.set_bool(False)
1756        try:
1757            r = pg.get_bool()
1758        finally:
1759            pg.set_bool(use_bool)
1760        self.assertIsInstance(r, bool)
1761        self.assertIs(r, False)
1762        pg.set_bool(True)
1763        try:
1764            r = pg.get_bool()
1765        finally:
1766            pg.set_bool(use_bool)
1767        self.assertIsInstance(r, bool)
1768        self.assertIs(r, True)
1769        pg.set_bool(0)
1770        try:
1771            r = pg.get_bool()
1772        finally:
1773            pg.set_bool(use_bool)
1774        self.assertIsInstance(r, bool)
1775        self.assertIs(r, False)
1776        pg.set_bool(1)
1777        try:
1778            r = pg.get_bool()
1779        finally:
1780            pg.set_bool(use_bool)
1781        self.assertIsInstance(r, bool)
1782        self.assertIs(r, True)
1783
1784    def testSetBool(self):
1785        use_bool = pg.get_bool()
1786        # error if no parameter is passed
1787        self.assertRaises(TypeError, pg.set_bool)
1788        query = self.c.query
1789        try:
1790            r = query("select true::bool")
1791        except pg.ProgrammingError:
1792            self.skipTest('database does not support bool')
1793        r = r.getresult()[0][0]
1794        self.assertIsInstance(r, bool)
1795        self.assertEqual(r, True)
1796        r = query("select true::bool")
1797        pg.set_bool(False)
1798        try:
1799            r = r.getresult()[0][0]
1800        finally:
1801            pg.set_bool(use_bool)
1802        self.assertIsInstance(r, str)
1803        self.assertIs(r, 't')
1804        r = query("select true::bool")
1805        pg.set_bool(True)
1806        try:
1807            r = r.getresult()[0][0]
1808        finally:
1809            pg.set_bool(use_bool)
1810        self.assertIsInstance(r, bool)
1811        self.assertIs(r, True)
1812
1813    def testGetByteEscaped(self):
1814        bytea_escaped = pg.get_bytea_escaped()
1815        # error if a parameter is passed
1816        self.assertRaises(TypeError, pg.get_bytea_escaped, bytea_escaped)
1817        self.assertIsInstance(bytea_escaped, bool)
1818        self.assertIs(bytea_escaped, False)  # the default setting
1819        pg.set_bytea_escaped(True)
1820        try:
1821            r = pg.get_bytea_escaped()
1822        finally:
1823            pg.set_bytea_escaped(bytea_escaped)
1824        self.assertIsInstance(r, bool)
1825        self.assertIs(r, True)
1826        pg.set_bytea_escaped(False)
1827        try:
1828            r = pg.get_bytea_escaped()
1829        finally:
1830            pg.set_bytea_escaped(bytea_escaped)
1831        self.assertIsInstance(r, bool)
1832        self.assertIs(r, False)
1833        pg.set_bytea_escaped(1)
1834        try:
1835            r = pg.get_bytea_escaped()
1836        finally:
1837            pg.set_bytea_escaped(bytea_escaped)
1838        self.assertIsInstance(r, bool)
1839        self.assertIs(r, True)
1840        pg.set_bytea_escaped(0)
1841        try:
1842            r = pg.get_bytea_escaped()
1843        finally:
1844            pg.set_bytea_escaped(bytea_escaped)
1845        self.assertIsInstance(r, bool)
1846        self.assertIs(r, False)
1847
1848    def testSetByteaEscaped(self):
1849        bytea_escaped = pg.get_bytea_escaped()
1850        # error if no parameter is passed
1851        self.assertRaises(TypeError, pg.set_bytea_escaped)
1852        query = self.c.query
1853        try:
1854            r = query("select 'data'::bytea")
1855        except pg.ProgrammingError:
1856            self.skipTest('database does not support bytea')
1857        r = r.getresult()[0][0]
1858        self.assertIsInstance(r, bytes)
1859        self.assertEqual(r, b'data')
1860        r = query("select 'data'::bytea")
1861        pg.set_bytea_escaped(True)
1862        try:
1863            r = r.getresult()[0][0]
1864        finally:
1865            pg.set_bytea_escaped(bytea_escaped)
1866        self.assertIsInstance(r, str)
1867        self.assertEqual(r, '\\x64617461')
1868        r = query("select 'data'::bytea")
1869        pg.set_bytea_escaped(False)
1870        try:
1871            r = r.getresult()[0][0]
1872        finally:
1873            pg.set_bytea_escaped(bytea_escaped)
1874        self.assertIsInstance(r, bytes)
1875        self.assertEqual(r, b'data')
1876
1877    def testGetNamedresult(self):
1878        namedresult = pg.get_namedresult()
1879        # error if a parameter is passed
1880        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
1881        self.assertIs(namedresult, pg._namedresult)  # the default setting
1882
1883    def testSetNamedresult(self):
1884        namedresult = pg.get_namedresult()
1885        self.assertTrue(callable(namedresult))
1886
1887        query = self.c.query
1888
1889        r = query("select 1 as x, 2 as y").namedresult()[0]
1890        self.assertIsInstance(r, tuple)
1891        self.assertEqual(r, (1, 2))
1892        self.assertIsNot(type(r), tuple)
1893        self.assertEqual(r._fields, ('x', 'y'))
1894        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
1895        self.assertEqual(r.__class__.__name__, 'Row')
1896
1897        def listresult(q):
1898            return [list(row) for row in q.getresult()]
1899
1900        pg.set_namedresult(listresult)
1901        try:
1902            r = pg.get_namedresult()
1903            self.assertIs(r, listresult)
1904            r = query("select 1 as x, 2 as y").namedresult()[0]
1905            self.assertIsInstance(r, list)
1906            self.assertEqual(r, [1, 2])
1907            self.assertIsNot(type(r), tuple)
1908            self.assertFalse(hasattr(r, '_fields'))
1909            self.assertNotEqual(r.__class__.__name__, 'Row')
1910        finally:
1911            pg.set_namedresult(namedresult)
1912
1913        r = pg.get_namedresult()
1914        self.assertIs(r, namedresult)
1915
1916    def testSetRowFactorySize(self):
1917        try:
1918            from functools import lru_cache
1919        except ImportError:  # Python < 3.2
1920            lru_cache = None
1921        queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc']
1922        query = self.c.query
1923        for maxsize in (None, 0, 1, 2, 3, 10, 1024):
1924            pg.set_row_factory_size(maxsize)
1925            for i in range(3):
1926                for q in queries:
1927                    r = query(q).namedresult()[0]
1928                    if q.endswith('abc'):
1929                        self.assertEqual(r, (123,))
1930                        self.assertEqual(r._fields, ('abc',))
1931                    else:
1932                        self.assertEqual(r, (1, 2, 3))
1933                        self.assertEqual(r._fields, ('a', 'b', 'c'))
1934            if lru_cache:
1935                info = pg._row_factory.cache_info()
1936                self.assertEqual(info.maxsize, maxsize)
1937                self.assertEqual(info.hits + info.misses, 6)
1938                self.assertEqual(info.hits,
1939                    0 if maxsize is not None and maxsize < 2 else 4)
1940
1941
1942class TestStandaloneEscapeFunctions(unittest.TestCase):
1943    """Test pg escape functions.
1944
1945    The libpq interface memorizes some parameters of the last opened
1946    connection that influence the result of these functions.  Therefore
1947    we need to open a connection with fixed parameters prior to testing
1948    in order to ensure that the tests always run under the same conditions.
1949    That's why these tests are included in this test module.
1950    """
1951
1952    cls_set_up = False
1953
1954    @classmethod
1955    def setUpClass(cls):
1956        db = connect()
1957        query = db.query
1958        query('set client_encoding=sql_ascii')
1959        query('set standard_conforming_strings=off')
1960        try:
1961            query('set bytea_output=escape')
1962        except pg.ProgrammingError:
1963            if db.server_version >= 90000:
1964                raise  # ignore for older server versions
1965        db.close()
1966        cls.cls_set_up = True
1967
1968    def testEscapeString(self):
1969        self.assertTrue(self.cls_set_up)
1970        f = pg.escape_string
1971        r = f(b'plain')
1972        self.assertIsInstance(r, bytes)
1973        self.assertEqual(r, b'plain')
1974        r = f(u'plain')
1975        self.assertIsInstance(r, unicode)
1976        self.assertEqual(r, u'plain')
1977        r = f(u"das is' kÀse".encode('utf-8'))
1978        self.assertIsInstance(r, bytes)
1979        self.assertEqual(r, u"das is'' kÀse".encode('utf-8'))
1980        r = f(u"that's cheesy")
1981        self.assertIsInstance(r, unicode)
1982        self.assertEqual(r, u"that''s cheesy")
1983        r = f(r"It's bad to have a \ inside.")
1984        self.assertEqual(r, r"It''s bad to have a \\ inside.")
1985
1986    def testEscapeBytea(self):
1987        self.assertTrue(self.cls_set_up)
1988        f = pg.escape_bytea
1989        r = f(b'plain')
1990        self.assertIsInstance(r, bytes)
1991        self.assertEqual(r, b'plain')
1992        r = f(u'plain')
1993        self.assertIsInstance(r, unicode)
1994        self.assertEqual(r, u'plain')
1995        r = f(u"das is' kÀse".encode('utf-8'))
1996        self.assertIsInstance(r, bytes)
1997        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
1998        r = f(u"that's cheesy")
1999        self.assertIsInstance(r, unicode)
2000        self.assertEqual(r, u"that''s cheesy")
2001        r = f(b'O\x00ps\xff!')
2002        self.assertEqual(r, b'O\\\\000ps\\\\377!')
2003
2004
# Run all tests of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
Note: See TracBrowser for help on using the repository browser.