source: trunk/tests/test_classic_connection.py @ 841

Last change on this file since 841 was 823, checked in by cito, 4 years ago

Raise the proper subclasses of DatabaseError.

In particular, we raise IntegrityError instead of ProgrammingError for
duplicate keys. This also makes PyGreSQL more usable with SQLAlchemy.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 66.9 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17import threading
18import time
19import os
20
21import pg  # the module under test
22
23from decimal import Decimal
24
# The tests need a database to run against.  The defaults below can be
# overridden by a module named LOCAL_PyGreSQL (looked up relative to the
# test package first, then as a top-level module); any names it defines
# replace the defaults.  Run these tests against various PostgreSQL versions
# and databases created with different encodings and locales -- in
# particular, against databases created with both SQL_ASCII and UTF8.
dbname, dbhost, dbport = 'unittest', None, 5432

try:  # local settings inside the test package
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:  # local settings as a top-level module
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local settings available, keep the defaults
41
# Python 2 has separate long and unicode types; under Python 3 both names
# are missing, so alias them to their Python 3 equivalents.
try:
    long, unicode
except NameError:  # Python >= 3.0
    long, unicode = int, str

# True when the native str type holds text rather than bytes.
unicode_strings = not (str is bytes)

windows = os.name == 'nt'

# A known bug in libpq can crash the interface when PQhost() is called
# under Windows, so tests must not ask for the host there.
do_not_ask_for_host, do_not_ask_for_host_reason = (
    windows, 'libpq issue on Windows')
60
61
def connect():
    """Create a basic pg connection to the test database.

    The connection uses the module-level dbname, dbhost and dbport
    settings, and client_min_messages is set to 'warning' so that
    lower-level server messages are not reported during the tests.

    If configuring the new connection fails, the connection is closed
    before the error is re-raised, so that no connection is leaked.
    """
    connection = pg.connect(dbname, dbhost, dbport)
    try:
        connection.query("set client_min_messages=warning")
    except Exception:
        connection.close()
        raise
    return connection
67
68
69class TestCanConnect(unittest.TestCase):
70    """Test whether a basic connection to PostgreSQL is possible."""
71
72    def testCanConnect(self):
73        try:
74            connection = connect()
75        except pg.Error as error:
76            self.fail('Cannot connect to database %s:\n%s' % (dbname, error))
77        try:
78            connection.close()
79        except pg.Error:
80            self.fail('Cannot close the database connection')
81
82
class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods."""

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        # Closing an already closed connection raises InternalError
        # (see testMethodClose), which is harmless here.
        try:
            self.connection.close()
        except pg.InternalError:
            pass

    def is_method(self, attribute):
        """Check if given attribute on the connection is a method.

        When do_not_ask_for_host is set, the 'host' attribute is reported
        as a non-method without being accessed, because asking libpq for
        the host can crash the interface on that platform.
        """
        if do_not_ask_for_host and attribute == 'host':
            return False
        return callable(getattr(self.connection, attribute))

    def testClassName(self):
        self.assertEqual(self.connection.__class__.__name__, 'Connection')

    def testModuleName(self):
        self.assertEqual(self.connection.__class__.__module__, 'pg')

    def testStr(self):
        r = str(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testRepr(self):
        r = repr(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testAllConnectAttributes(self):
        # dir() returns its names sorted, so the expected list is sorted too
        attributes = '''db error host options port
            protocol_version server_version status user'''.split()
        connection_attributes = [a for a in dir(self.connection)
            if not a.startswith('__') and not self.is_method(a)]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        methods = '''cancel close date_format endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_cast_hook get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter putline query reset
            set_cast_hook set_notice_receiver source transaction'''.split()
        connection_methods = [a for a in dir(self.connection)
            if not a.startswith('__') and self.is_method(a)]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        # no error message is expected, except one mentioning krb5_
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, dbhost or def_host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        # Versions before 10 are encoded like 90405 (9.4.5); PostgreSQL 10
        # and later use major * 10000 + minor (e.g. 100004), so the upper
        # bound must allow six-digit values.  It is only a sanity check.
        self.assertTrue(70400 <= server_version < 1000000)

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testMethodEndcopy(self):
        # calling endcopy outside of a copy may raise IOError, which is fine
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        # reset and query must fail on a closed connection,
        # and closing it a second time must raise InternalError
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # provide a fresh connection for tearDown
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query
        # check that client encoding gets reset
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        self.connection.query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            # the cancelled query is expected to raise a DatabaseError
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.DatabaseError as error:
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        while 1:  # make sure the query is really running
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)

    def testMethodTransaction(self):
        transaction = self.connection.transaction
        self.assertRaises(TypeError, transaction, None)
        self.assertEqual(transaction(), pg.TRANS_IDLE)
        self.connection.query('begin')
        self.assertEqual(transaction(), pg.TRANS_INTRANS)
        self.connection.query('rollback')
        self.assertEqual(transaction(), pg.TRANS_IDLE)

    def testMethodParameter(self):
        parameter = self.connection.parameter
        query = self.connection.query
        self.assertRaises(TypeError, parameter)
        r = parameter('this server setting does not exist')
        self.assertIsNone(r)
        # Each reported parameter must agree with what SHOW returns.  The
        # values are compared as they are, without case folding, since
        # e.g. the server version string may contain letters.
        for name in ('server_version', 'server_encoding', 'client_encoding'):
            s = query('show %s' % name).getresult()[0][0]
            self.assertIsNotNone(s)
            self.assertEqual(parameter(name), s, name)
289
290
291class TestSimpleQueries(unittest.TestCase):
292    """Test simple queries via a basic pg connection."""
293
294    def setUp(self):
295        self.c = connect()
296
297    def tearDown(self):
298        self.doCleanups()
299        self.c.close()
300
301    def testClassName(self):
302        r = self.c.query("select 1")
303        self.assertEqual(r.__class__.__name__, 'Query')
304
305    def testModuleName(self):
306        r = self.c.query("select 1")
307        self.assertEqual(r.__class__.__module__, 'pg')
308
309    def testStr(self):
310        q = ("select 1 as a, 'hello' as h, 'w' as world"
311            " union select 2, 'xyz', 'uvw'")
312        r = self.c.query(q)
313        self.assertEqual(str(r),
314            'a|  h  |world\n'
315            '-+-----+-----\n'
316            '1|hello|w    \n'
317            '2|xyz  |uvw  \n'
318            '(2 rows)')
319
320    def testRepr(self):
321        r = repr(self.c.query("select 1"))
322        self.assertTrue(r.startswith('<pg.Query object'), r)
323
324    def testSelect0(self):
325        q = "select 0"
326        self.c.query(q)
327
328    def testSelect0Semicolon(self):
329        q = "select 0;"
330        self.c.query(q)
331
332    def testSelectDotSemicolon(self):
333        q = "select .;"
334        self.assertRaises(pg.DatabaseError, self.c.query, q)
335
336    def testGetresult(self):
337        q = "select 0"
338        result = [(0,)]
339        r = self.c.query(q).getresult()
340        self.assertIsInstance(r, list)
341        v = r[0]
342        self.assertIsInstance(v, tuple)
343        self.assertIsInstance(v[0], int)
344        self.assertEqual(r, result)
345
346    def testGetresultLong(self):
347        q = "select 9876543210"
348        result = long(9876543210)
349        self.assertIsInstance(result, long)
350        v = self.c.query(q).getresult()[0][0]
351        self.assertIsInstance(v, long)
352        self.assertEqual(v, result)
353
354    def testGetresultDecimal(self):
355        q = "select 98765432109876543210"
356        result = Decimal(98765432109876543210)
357        v = self.c.query(q).getresult()[0][0]
358        self.assertIsInstance(v, Decimal)
359        self.assertEqual(v, result)
360
361    def testGetresultString(self):
362        result = 'Hello, world!'
363        q = "select '%s'" % result
364        v = self.c.query(q).getresult()[0][0]
365        self.assertIsInstance(v, str)
366        self.assertEqual(v, result)
367
368    def testDictresult(self):
369        q = "select 0 as alias0"
370        result = [{'alias0': 0}]
371        r = self.c.query(q).dictresult()
372        self.assertIsInstance(r, list)
373        v = r[0]
374        self.assertIsInstance(v, dict)
375        self.assertIsInstance(v['alias0'], int)
376        self.assertEqual(r, result)
377
378    def testDictresultLong(self):
379        q = "select 9876543210 as longjohnsilver"
380        result = long(9876543210)
381        self.assertIsInstance(result, long)
382        v = self.c.query(q).dictresult()[0]['longjohnsilver']
383        self.assertIsInstance(v, long)
384        self.assertEqual(v, result)
385
386    def testDictresultDecimal(self):
387        q = "select 98765432109876543210 as longjohnsilver"
388        result = Decimal(98765432109876543210)
389        v = self.c.query(q).dictresult()[0]['longjohnsilver']
390        self.assertIsInstance(v, Decimal)
391        self.assertEqual(v, result)
392
393    def testDictresultString(self):
394        result = 'Hello, world!'
395        q = "select '%s' as greeting" % result
396        v = self.c.query(q).dictresult()[0]['greeting']
397        self.assertIsInstance(v, str)
398        self.assertEqual(v, result)
399
400    def testNamedresult(self):
401        q = "select 0 as alias0"
402        result = [(0,)]
403        r = self.c.query(q).namedresult()
404        self.assertEqual(r, result)
405        v = r[0]
406        self.assertEqual(v._fields, ('alias0',))
407        self.assertEqual(v.alias0, 0)
408
409    def testGet3Cols(self):
410        q = "select 1,2,3"
411        result = [(1, 2, 3)]
412        r = self.c.query(q).getresult()
413        self.assertEqual(r, result)
414
415    def testGet3DictCols(self):
416        q = "select 1 as a,2 as b,3 as c"
417        result = [dict(a=1, b=2, c=3)]
418        r = self.c.query(q).dictresult()
419        self.assertEqual(r, result)
420
421    def testGet3NamedCols(self):
422        q = "select 1 as a,2 as b,3 as c"
423        result = [(1, 2, 3)]
424        r = self.c.query(q).namedresult()
425        self.assertEqual(r, result)
426        v = r[0]
427        self.assertEqual(v._fields, ('a', 'b', 'c'))
428        self.assertEqual(v.b, 2)
429
430    def testGet3Rows(self):
431        q = "select 3 union select 1 union select 2 order by 1"
432        result = [(1,), (2,), (3,)]
433        r = self.c.query(q).getresult()
434        self.assertEqual(r, result)
435
436    def testGet3DictRows(self):
437        q = ("select 3 as alias3"
438            " union select 1 union select 2 order by 1")
439        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
440        r = self.c.query(q).dictresult()
441        self.assertEqual(r, result)
442
443    def testGet3NamedRows(self):
444        q = ("select 3 as alias3"
445            " union select 1 union select 2 order by 1")
446        result = [(1,), (2,), (3,)]
447        r = self.c.query(q).namedresult()
448        self.assertEqual(r, result)
449        for v in r:
450            self.assertEqual(v._fields, ('alias3',))
451
452    def testDictresultNames(self):
453        q = "select 'MixedCase' as MixedCaseAlias"
454        result = [{'mixedcasealias': 'MixedCase'}]
455        r = self.c.query(q).dictresult()
456        self.assertEqual(r, result)
457        q = "select 'MixedCase' as \"MixedCaseAlias\""
458        result = [{'MixedCaseAlias': 'MixedCase'}]
459        r = self.c.query(q).dictresult()
460        self.assertEqual(r, result)
461
462    def testNamedresultNames(self):
463        q = "select 'MixedCase' as MixedCaseAlias"
464        result = [('MixedCase',)]
465        r = self.c.query(q).namedresult()
466        self.assertEqual(r, result)
467        v = r[0]
468        self.assertEqual(v._fields, ('mixedcasealias',))
469        self.assertEqual(v.mixedcasealias, 'MixedCase')
470        q = "select 'MixedCase' as \"MixedCaseAlias\""
471        r = self.c.query(q).namedresult()
472        self.assertEqual(r, result)
473        v = r[0]
474        self.assertEqual(v._fields, ('MixedCaseAlias',))
475        self.assertEqual(v.MixedCaseAlias, 'MixedCase')
476
477    def testBigGetresult(self):
478        num_cols = 100
479        num_rows = 100
480        q = "select " + ','.join(map(str, range(num_cols)))
481        q = ' union all '.join((q,) * num_rows)
482        r = self.c.query(q).getresult()
483        result = [tuple(range(num_cols))] * num_rows
484        self.assertEqual(r, result)
485
486    def testListfields(self):
487        q = ('select 0 as a, 0 as b, 0 as c,'
488            ' 0 as c, 0 as b, 0 as a,'
489            ' 0 as lowercase, 0 as UPPERCASE,'
490            ' 0 as MixedCase, 0 as "MixedCase",'
491            ' 0 as a_long_name_with_underscores,'
492            ' 0 as "A long name with Blanks"')
493        r = self.c.query(q).listfields()
494        result = ('a', 'b', 'c', 'c', 'b', 'a',
495            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
496            'a_long_name_with_underscores',
497            'A long name with Blanks')
498        self.assertEqual(r, result)
499
500    def testFieldname(self):
501        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
502        r = self.c.query(q).fieldname(2)
503        self.assertEqual(r, 'x')
504        r = self.c.query(q).fieldname(3)
505        self.assertEqual(r, 'y')
506
507    def testFieldnum(self):
508        q = "select 1 as x"
509        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
510        q = "select 1 as x"
511        r = self.c.query(q).fieldnum('x')
512        self.assertIsInstance(r, int)
513        self.assertEqual(r, 0)
514        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
515        r = self.c.query(q).fieldnum('x')
516        self.assertIsInstance(r, int)
517        self.assertEqual(r, 2)
518        r = self.c.query(q).fieldnum('y')
519        self.assertIsInstance(r, int)
520        self.assertEqual(r, 3)
521
522    def testNtuples(self):
523        q = "select 1 where false"
524        r = self.c.query(q).ntuples()
525        self.assertIsInstance(r, int)
526        self.assertEqual(r, 0)
527        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
528            " union select 5 as a, 6 as b, 7 as c, 8 as d")
529        r = self.c.query(q).ntuples()
530        self.assertIsInstance(r, int)
531        self.assertEqual(r, 2)
532        q = ("select 1 union select 2 union select 3"
533            " union select 4 union select 5 union select 6")
534        r = self.c.query(q).ntuples()
535        self.assertIsInstance(r, int)
536        self.assertEqual(r, 6)
537
538    def testQuery(self):
539        query = self.c.query
540        query("drop table if exists test_table")
541        self.addCleanup(query, "drop table test_table")
542        q = "create table test_table (n integer) with oids"
543        r = query(q)
544        self.assertIsNone(r)
545        q = "insert into test_table values (1)"
546        r = query(q)
547        self.assertIsInstance(r, int)
548        q = "insert into test_table select 2"
549        r = query(q)
550        self.assertIsInstance(r, int)
551        oid = r
552        q = "select oid from test_table where n=2"
553        r = query(q).getresult()
554        self.assertEqual(len(r), 1)
555        r = r[0]
556        self.assertEqual(len(r), 1)
557        r = r[0]
558        self.assertIsInstance(r, int)
559        self.assertEqual(r, oid)
560        q = "insert into test_table select 3 union select 4 union select 5"
561        r = query(q)
562        self.assertIsInstance(r, str)
563        self.assertEqual(r, '3')
564        q = "update test_table set n=4 where n<5"
565        r = query(q)
566        self.assertIsInstance(r, str)
567        self.assertEqual(r, '4')
568        q = "delete from test_table"
569        r = query(q)
570        self.assertIsInstance(r, str)
571        self.assertEqual(r, '5')
572
573
574class TestUnicodeQueries(unittest.TestCase):
575    """Test unicode strings as queries via a basic pg connection."""
576
577    def setUp(self):
578        self.c = connect()
579        self.c.query('set client_encoding=utf8')
580
581    def tearDown(self):
582        self.c.close()
583
584    def testGetresulAscii(self):
585        result = u'Hello, world!'
586        q = u"select '%s'" % result
587        v = self.c.query(q).getresult()[0][0]
588        self.assertIsInstance(v, str)
589        self.assertEqual(v, result)
590
591    def testDictresulAscii(self):
592        result = u'Hello, world!'
593        q = u"select '%s' as greeting" % result
594        v = self.c.query(q).dictresult()[0]['greeting']
595        self.assertIsInstance(v, str)
596        self.assertEqual(v, result)
597
598    def testGetresultUtf8(self):
599        result = u'Hello, wörld & ЌОр!'
600        q = u"select '%s'" % result
601        if not unicode_strings:
602            result = result.encode('utf8')
603        # pass the query as unicode
604        try:
605            v = self.c.query(q).getresult()[0][0]
606        except(pg.DataError, pg.NotSupportedError):
607            self.skipTest("database does not support utf8")
608        self.assertIsInstance(v, str)
609        self.assertEqual(v, result)
610        q = q.encode('utf8')
611        # pass the query as bytes
612        v = self.c.query(q).getresult()[0][0]
613        self.assertIsInstance(v, str)
614        self.assertEqual(v, result)
615
616    def testDictresultUtf8(self):
617        result = u'Hello, wörld & ЌОр!'
618        q = u"select '%s' as greeting" % result
619        if not unicode_strings:
620            result = result.encode('utf8')
621        try:
622            v = self.c.query(q).dictresult()[0]['greeting']
623        except (pg.DataError, pg.NotSupportedError):
624            self.skipTest("database does not support utf8")
625        self.assertIsInstance(v, str)
626        self.assertEqual(v, result)
627        q = q.encode('utf8')
628        v = self.c.query(q).dictresult()[0]['greeting']
629        self.assertIsInstance(v, str)
630        self.assertEqual(v, result)
631
632    def testDictresultLatin1(self):
633        try:
634            self.c.query('set client_encoding=latin1')
635        except (pg.DataError, pg.NotSupportedError):
636            self.skipTest("database does not support latin1")
637        result = u'Hello, wörld!'
638        q = u"select '%s'" % result
639        if not unicode_strings:
640            result = result.encode('latin1')
641        v = self.c.query(q).getresult()[0][0]
642        self.assertIsInstance(v, str)
643        self.assertEqual(v, result)
644        q = q.encode('latin1')
645        v = self.c.query(q).getresult()[0][0]
646        self.assertIsInstance(v, str)
647        self.assertEqual(v, result)
648
649    def testDictresultLatin1(self):
650        try:
651            self.c.query('set client_encoding=latin1')
652        except (pg.DataError, pg.NotSupportedError):
653            self.skipTest("database does not support latin1")
654        result = u'Hello, wörld!'
655        q = u"select '%s' as greeting" % result
656        if not unicode_strings:
657            result = result.encode('latin1')
658        v = self.c.query(q).dictresult()[0]['greeting']
659        self.assertIsInstance(v, str)
660        self.assertEqual(v, result)
661        q = q.encode('latin1')
662        v = self.c.query(q).dictresult()[0]['greeting']
663        self.assertIsInstance(v, str)
664        self.assertEqual(v, result)
665
666    def testGetresultCyrillic(self):
667        try:
668            self.c.query('set client_encoding=iso_8859_5')
669        except (pg.DataError, pg.NotSupportedError):
670            self.skipTest("database does not support cyrillic")
671        result = u'Hello, ЌОр!'
672        q = u"select '%s'" % result
673        if not unicode_strings:
674            result = result.encode('cyrillic')
675        v = self.c.query(q).getresult()[0][0]
676        self.assertIsInstance(v, str)
677        self.assertEqual(v, result)
678        q = q.encode('cyrillic')
679        v = self.c.query(q).getresult()[0][0]
680        self.assertIsInstance(v, str)
681        self.assertEqual(v, result)
682
683    def testDictresultCyrillic(self):
684        try:
685            self.c.query('set client_encoding=iso_8859_5')
686        except (pg.DataError, pg.NotSupportedError):
687            self.skipTest("database does not support cyrillic")
688        result = u'Hello, ЌОр!'
689        q = u"select '%s' as greeting" % result
690        if not unicode_strings:
691            result = result.encode('cyrillic')
692        v = self.c.query(q).dictresult()[0]['greeting']
693        self.assertIsInstance(v, str)
694        self.assertEqual(v, result)
695        q = q.encode('cyrillic')
696        v = self.c.query(q).dictresult()[0]['greeting']
697        self.assertIsInstance(v, str)
698        self.assertEqual(v, result)
699
700    def testGetresultLatin9(self):
701        try:
702            self.c.query('set client_encoding=latin9')
703        except (pg.DataError, pg.NotSupportedError):
704            self.skipTest("database does not support latin9")
705        result = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
706        q = u"select '%s'" % result
707        if not unicode_strings:
708            result = result.encode('latin9')
709        v = self.c.query(q).getresult()[0][0]
710        self.assertIsInstance(v, str)
711        self.assertEqual(v, result)
712        q = q.encode('latin9')
713        v = self.c.query(q).getresult()[0][0]
714        self.assertIsInstance(v, str)
715        self.assertEqual(v, result)
716
717    def testDictresultLatin9(self):
718        try:
719            self.c.query('set client_encoding=latin9')
720        except (pg.DataError, pg.NotSupportedError):
721            self.skipTest("database does not support latin9")
722        result = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
723        q = u"select '%s' as menu" % result
724        if not unicode_strings:
725            result = result.encode('latin9')
726        v = self.c.query(q).dictresult()[0]['menu']
727        self.assertIsInstance(v, str)
728        self.assertEqual(v, result)
729        q = q.encode('latin9')
730        v = self.c.query(q).dictresult()[0]['menu']
731        self.assertIsInstance(v, str)
732        self.assertEqual(v, result)
733
734
735class TestParamQueries(unittest.TestCase):
736    """Test queries with parameters via a basic pg connection."""
737
738    def setUp(self):
739        self.c = connect()
740        self.c.query('set client_encoding=utf8')
741
742    def tearDown(self):
743        self.c.close()
744
745    def testQueryWithNoneParam(self):
746        self.assertRaises(TypeError, self.c.query, "select $1", None)
747        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
748        self.assertEqual(self.c.query("select $1::integer", (None,)
749            ).getresult(), [(None,)])
750        self.assertEqual(self.c.query("select $1::text", [None]
751            ).getresult(), [(None,)])
752        self.assertEqual(self.c.query("select $1::text", [[None]]
753            ).getresult(), [(None,)])
754
755    def testQueryWithBoolParams(self, bool_enabled=None):
756        query = self.c.query
757        if bool_enabled is not None:
758            bool_enabled_default = pg.get_bool()
759            pg.set_bool(bool_enabled)
760        try:
761            bool_on = bool_enabled or bool_enabled is None
762            v_false, v_true = (False, True) if bool_on else 'ft'
763            r_false, r_true = [(v_false,)], [(v_true,)]
764            self.assertEqual(query("select false").getresult(), r_false)
765            self.assertEqual(query("select true").getresult(), r_true)
766            q = "select $1::bool"
767            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
768            self.assertEqual(query(q, ('f',)).getresult(), r_false)
769            self.assertEqual(query(q, ('t',)).getresult(), r_true)
770            self.assertEqual(query(q, ('false',)).getresult(), r_false)
771            self.assertEqual(query(q, ('true',)).getresult(), r_true)
772            self.assertEqual(query(q, ('n',)).getresult(), r_false)
773            self.assertEqual(query(q, ('y',)).getresult(), r_true)
774            self.assertEqual(query(q, (0,)).getresult(), r_false)
775            self.assertEqual(query(q, (1,)).getresult(), r_true)
776            self.assertEqual(query(q, (False,)).getresult(), r_false)
777            self.assertEqual(query(q, (True,)).getresult(), r_true)
778        finally:
779            if bool_enabled is not None:
780                pg.set_bool(bool_enabled_default)
781
782    def testQueryWithBoolParamsNotDefault(self):
783        self.testQueryWithBoolParams(bool_enabled=not pg.get_bool())
784
785    def testQueryWithIntParams(self):
786        query = self.c.query
787        self.assertEqual(query("select 1+1").getresult(), [(2,)])
788        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
789        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
790        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
791        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
792        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
793            [(Decimal('2'),)])
794        self.assertEqual(query("select 1, $1::integer", (2,)
795            ).getresult(), [(1, 2)])
796        self.assertEqual(query("select 1 union select $1::integer", (2,)
797            ).getresult(), [(1,), (2,)])
798        self.assertEqual(query("select $1::integer+$2", (1, 2)
799            ).getresult(), [(3,)])
800        self.assertEqual(query("select $1::integer+$2", [1, 2]
801            ).getresult(), [(3,)])
802        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
803            ).getresult(), [(15,)])
804
805    def testQueryWithStrParams(self):
806        query = self.c.query
807        self.assertEqual(query("select $1||', world!'", ('Hello',)
808            ).getresult(), [('Hello, world!',)])
809        self.assertEqual(query("select $1||', world!'", ['Hello']
810            ).getresult(), [('Hello, world!',)])
811        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
812            ).getresult(), [('Hello, world!',)])
813        self.assertEqual(query("select $1::text", ('Hello, world!',)
814            ).getresult(), [('Hello, world!',)])
815        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
816            ).getresult(), [('Hello', 'world')])
817        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
818            ).getresult(), [('Hello', 'world')])
819        self.assertEqual(query("select $1::text union select $2::text",
820            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
821        try:
822            query("select 'wörld'")
823        except (pg.DataError, pg.NotSupportedError):
824            self.skipTest('database does not support utf8')
825        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
826            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
827
828    def testQueryWithUnicodeParams(self):
829        query = self.c.query
830        try:
831            query('set client_encoding=utf8')
832            query("select 'wörld'").getresult()[0][0] == 'wörld'
833        except (pg.DataError, pg.NotSupportedError):
834            self.skipTest("database does not support utf8")
835        self.assertEqual(query("select $1||', '||$2||'!'",
836            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
837
838    def testQueryWithUnicodeParamsLatin1(self):
839        query = self.c.query
840        try:
841            query('set client_encoding=latin1')
842            query("select 'wörld'").getresult()[0][0] == 'wörld'
843        except (pg.DataError, pg.NotSupportedError):
844            self.skipTest("database does not support latin1")
845        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
846        if unicode_strings:
847            self.assertEqual(r, [('Hello, wörld!',)])
848        else:
849            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
850        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
851            ('Hello', u'ЌОр'))
852        query('set client_encoding=iso_8859_1')
853        r = query("select $1||', '||$2||'!'",
854            ('Hello', u'wörld')).getresult()
855        if unicode_strings:
856            self.assertEqual(r, [('Hello, wörld!',)])
857        else:
858            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
859        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
860            ('Hello', u'ЌОр'))
861        query('set client_encoding=sql_ascii')
862        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
863            ('Hello', u'wörld'))
864
865    def testQueryWithUnicodeParamsCyrillic(self):
866        query = self.c.query
867        try:
868            query('set client_encoding=iso_8859_5')
869            query("select 'ЌОр'").getresult()[0][0] == 'ЌОр'
870        except (pg.DataError, pg.NotSupportedError):
871            self.skipTest("database does not support cyrillic")
872        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
873            ('Hello', u'wörld'))
874        r = query("select $1||', '||$2||'!'",
875            ('Hello', u'ЌОр')).getresult()
876        if unicode_strings:
877            self.assertEqual(r, [('Hello, ЌОр!',)])
878        else:
879            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
880        query('set client_encoding=sql_ascii')
881        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
882            ('Hello', u'ЌОр!'))
883
884    def testQueryWithMixedParams(self):
885        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
886            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
887        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
888            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
889
890    def testQueryWithDuplicateParams(self):
891        self.assertRaises(pg.ProgrammingError,
892            self.c.query, "select $1+$1", (1,))
893        self.assertRaises(pg.ProgrammingError,
894            self.c.query, "select $1+$1", (1, 2))
895
896    def testQueryWithZeroParams(self):
897        self.assertEqual(self.c.query("select 1+1", []
898            ).getresult(), [(2,)])
899
900    def testQueryWithGarbage(self):
901        garbage = r"'\{}+()-#[]oo324"
902        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
903            ).dictresult(), [{'garbage': garbage}])
904
905
class TestQueryResultTypes(unittest.TestCase):
    """Test proper result types via a basic pg connection."""

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        self.c.close()

    def assert_proper_cast(self, value, pgtype, pytype):
        """Check the Python type of a parameter cast to pgtype.

        The value is passed as a query parameter and cast to pgtype; the
        result must be an instance of pytype.  Then the value is wrapped in
        a one-element array literal and cast to pgtype[].  For date, time
        and interval types the array literal must come back unchanged,
        otherwise a one-element list with an element of pytype is expected.

        The test is skipped if a json/jsonb cast raises ProgrammingError;
        a ProgrammingError for any other type is re-raised.
        """
        q = 'select $1::%s' % (pgtype,)
        try:
            r = self.c.query(q, (value,)).getresult()[0][0]
        except pg.ProgrammingError:
            if pgtype in ('json', 'jsonb'):
                self.skipTest('database does not support json')
            # any other failing cast is a genuine error, do not continue
            # with an undefined result
            raise
        self.assertIsInstance(r, pytype)
        # quote strings that are empty or contain spaces or braces
        # before embedding them in the array literal
        if isinstance(value, str):
            if not value or ' ' in value or '{' in value:
                value = '"%s"' % value
        value = '{%s}' % value
        r = self.c.query(q + '[]', (value,)).getresult()[0][0]
        if pgtype.startswith(('date', 'time', 'interval')):
            # arrays of these are casted by the DB wrapper only
            self.assertEqual(r, value)
        else:
            self.assertIsInstance(r, list)
            self.assertEqual(len(r), 1)
            self.assertIsInstance(r[0], pytype)

    def testInt(self):
        self.assert_proper_cast(0, 'int', int)
        self.assert_proper_cast(0, 'smallint', int)
        self.assert_proper_cast(0, 'oid', int)
        self.assert_proper_cast(0, 'cid', int)
        self.assert_proper_cast(0, 'xid', int)

    def testLong(self):
        self.assert_proper_cast(0, 'bigint', long)

    def testFloat(self):
        """Check floating point types; "double" alone is not a PG type."""
        self.assert_proper_cast(0, 'float', float)
        self.assert_proper_cast(0, 'real', float)
        self.assert_proper_cast(0, 'double precision', float)
        self.assert_proper_cast('infinity', 'float', float)

    def testNumeric(self):
        """Check that numeric types use the configured decimal class."""
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal(0), 'numeric', decimal)
        self.assert_proper_cast(decimal(0), 'decimal', decimal)

    def testMoney(self):
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal('0'), 'money', decimal)

    def testBool(self):
        bool_type = bool if pg.get_bool() else str
        self.assert_proper_cast('f', 'bool', bool_type)

    def testDate(self):
        self.assert_proper_cast('1956-01-31', 'date', str)
        self.assert_proper_cast('10:20:30', 'interval', str)
        self.assert_proper_cast('08:42:15', 'time', str)
        self.assert_proper_cast('08:42:15+00', 'timetz', str)
        self.assert_proper_cast('1956-01-31 08:42:15', 'timestamp', str)
        self.assert_proper_cast('1956-01-31 08:42:15+00', 'timestamptz', str)

    def testText(self):
        self.assert_proper_cast('', 'text', str)
        self.assert_proper_cast('', 'char', str)
        self.assert_proper_cast('', 'bpchar', str)
        self.assert_proper_cast('', 'varchar', str)

    def testBytea(self):
        self.assert_proper_cast('', 'bytea', bytes)

    def testJson(self):
        self.assert_proper_cast('{}', 'json', dict)
987
988
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    cls_set_up = False
    # has_encoding is set by setUpClass: True if the database counts a
    # non-ASCII character as one character when calculating lengths

    @classmethod
    def setUpClass(cls):
        """Create the test table and detect how the database counts lengths."""
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test ("
                "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
                "d numeric, f4 real, f8 double precision, m money,"
                "c char(1), v4 varchar(4), c4 char(4), t text)")
            # Check whether the test database uses SQL_ASCII - this means
            # that it does not consider encoding when calculating lengths.
            c.query("set client_encoding=utf8")
            try:
                c.query("select 'À'")
            except (pg.DataError, pg.NotSupportedError):
                cls.has_encoding = False
            else:
                cls.has_encoding = c.query(
                    "select length('À') - length('a')").getresult()[0][0] == 0
        finally:
            # do not leak the connection if creating the table fails
            c.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the test table."""
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        try:
            self.c.query("truncate table test")
        finally:
            # close the connection even if truncating fails
            self.c.close()

    # sample rows, one value per column of the test table
    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def db_len(cls, s, encoding):
        """Return the length of s as the database would count it.

        With an encoding-aware database this is the number of characters,
        otherwise the number of bytes in the given encoding.
        """
        if cls.has_encoding:
            s = s if isinstance(s, unicode) else s.decode(encoding)
        else:
            s = s.encode(encoding) if isinstance(s, unicode) else s
        return len(s)

    def get_back(self, encoding='utf-8'):
        """Fetch all rows of the test table, checking the value types.

        Rows are ordered by the first column.  Every non-null value is
        checked for its expected Python type and the character columns for
        their length.  Numeric, money and char(4) values are converted back
        to the form used in the inserted data, so the returned list of
        tuples can be compared with it directly.
        """
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], bool)
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision
                self.assertIsInstance(row[8], float)
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(self.db_len(row[10], encoding), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(self.db_len(row[11], encoding), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(self.db_len(row[12], encoding), 4)
                # char(4) is blank-padded by the database
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            row = tuple(row)
            data.append(row)
        return data

    def testInserttable1Row(self):
        data = self.data[2:3]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        data = self.data
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMultipleRows(self):
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        data = [(None,) * 14] * 100
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableByteValues(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"kÀse сыр pont-l'évêque")
        row_bytes = tuple(s.encode('utf-8')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_bytes] * 2
        self.c.inserttable('test', data)
        if unicode_strings:
            data = [row_unicode] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeUtf8(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"kÀse сыр pont-l'évêque")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            row_bytes = tuple(s.encode('utf-8')
                if isinstance(s, unicode) else s for s in row_unicode)
            data = [row_bytes] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeLatin1(self):
        try:
            self.c.query("set client_encoding=latin1")
            self.c.query("select 'Â¥'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode € sign with latin1 encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
        row_unicode = tuple(s.replace(u'€', u'Â¥')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            row_bytes = tuple(s.encode('latin1')
                if isinstance(s, unicode) else s for s in row_unicode)
            data = [row_bytes] * 2
        self.assertEqual(self.get_back('latin1'), data)

    def testInserttableUnicodeLatin9(self):
        try:
            self.c.query("set client_encoding=latin9")
            self.c.query("select '€'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            row_bytes = tuple(s.encode('latin9')
                if isinstance(s, unicode) else s for s in row_unicode)
            data = [row_bytes] * 2
        self.assertEqual(self.get_back('latin9'), data)

    def testInserttableNoEncoding(self):
        self.c.query("set client_encoding=sql_ascii")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode non-ascii unicode without a specific encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
1226
1227
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Create the table used by the copy tests."""
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test (i int, v varchar(16))")
        finally:
            # do not leak the connection if creating the table fails
            c.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the table used by the copy tests."""
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")

    def tearDown(self):
        try:
            self.c.query("truncate table test")
        finally:
            # close the connection even if truncating fails
            self.c.close()

    def testPutline(self):
        """Copy rows into the table by sending lines with putline."""
        putline = self.c.putline
        query = self.c.query
        data = list(enumerate("apple pear plum cherry banana".split()))
        query("copy test from stdin")
        try:
            for i, v in data:
                putline("%d\t%s\n" % (i, v))
            putline("\\.\n")  # end-of-data marker
        finally:
            self.c.endcopy()
        r = query("select * from test").getresult()
        self.assertEqual(r, data)

    def testPutlineBytesAndUnicode(self):
        """Check that putline accepts both bytes and str lines."""
        putline = self.c.putline
        query = self.c.query
        try:
            query("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        query("copy test from stdin")
        try:
            putline(u"47\tkÀse\n".encode('utf8'))
            putline("35\twÃŒrstel\n")
            putline(b"\\.\n")
        finally:
            self.c.endcopy()
        r = query("select * from test").getresult()
        self.assertEqual(r, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        """Read rows copied out of the table line by line with getline."""
        getline = self.c.getline
        query = self.c.query
        data = list(enumerate("apple banana pear plum strawberry".split()))
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            for row in data:
                self.assertEqual(getline(), '%d\t%s' % row)
            # after the rows come the end-of-data marker and then None
            self.assertEqual(getline(), '\\.')
            self.assertIsNone(getline())
        finally:
            try:
                self.c.endcopy()
            except IOError:
                # NOTE(review): tolerated deliberately; presumably endcopy
                # can fail once the copy data was read completely - confirm
                pass

    def testGetlineBytesAndUnicode(self):
        """Check that getline returns str for rows inserted as bytes or str."""
        getline = self.c.getline
        query = self.c.query
        try:
            query("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        data = [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')]
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            v = getline()
            self.assertIsInstance(v, str)
            self.assertEqual(v, '54\tkÀse')
            v = getline()
            self.assertIsInstance(v, str)
            self.assertEqual(v, '73\twÃŒrstel')
            self.assertEqual(getline(), '\\.')
            self.assertIsNone(getline())
        finally:
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        """Check that wrong numbers of arguments raise TypeError."""
        self.assertRaises(TypeError, self.c.putline)
        self.assertRaises(TypeError, self.c.getline, 'invalid')
        self.assertRaises(TypeError, self.c.endcopy, 'invalid')
1338
1339
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # run the registered cleanups now, while the connection is still
        # open, because they may issue queries on it
        self.doCleanups()
        self.c.close()

    def testGetNotify(self):
        """Check getnotify before, while and after listening on a channel."""
        getnotify = self.c.getnotify
        query = self.c.query

        def check_notify(payload):
            """Fetch one notification and check its (channel, int, payload)
            tuple against the expected payload."""
            r = getnotify()
            self.assertIsInstance(r, tuple)
            self.assertEqual(len(r), 3)
            self.assertIsInstance(r[0], str)
            self.assertIsInstance(r[1], int)
            self.assertIsInstance(r[2], str)
            self.assertEqual(r[0], 'test_notify')
            self.assertEqual(r[2], payload)

        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            check_notify('')
            self.assertIsNone(getnotify())
            query("notify test_notify, 'test_payload'")
            check_notify('test_payload')
            self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 'invalid')
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))
        self.assertIsNone(self.c.set_notice_receiver(None))

    def testSetAndGetNoticeReceiver(self):
        r = lambda notice: None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)
        self.assertIsNone(self.c.set_notice_receiver(None))
        self.assertIsNone(self.c.get_notice_receiver())

    def testNoticeReceiver(self):
        """Check that a warning raised by a function reaches the receiver."""
        # the drop runs in tearDown, before the connection is closed
        self.addCleanup(self.c.query, 'drop function bilbo_notice();')
        self.c.query('''create function bilbo_notice() returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        received = {}

        def notice_receiver(notice):
            """Collect all public attributes of the notice."""
            for attr in dir(notice):
                if attr.startswith('__'):
                    continue
                value = getattr(notice, attr)
                if isinstance(value, str):
                    # normalize a German-localized server severity
                    value = value.replace('WARNUNG', 'WARNING')
                received[attr] = value

        self.c.set_notice_receiver(notice_receiver)
        self.c.query('select bilbo_notice()')
        self.assertEqual(received, dict(
            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
            severity='WARNING', primary='Bilbo was here!',
            detail=None, hint=None))
1420
1421
1422class TestConfigFunctions(unittest.TestCase):
1423    """Test the functions for changing default settings.
1424
1425    To test the effect of most of these functions, we need a database
1426    connection.  That's why they are covered in this test module.
1427    """
1428
1429    def setUp(self):
1430        self.c = connect()
1431        self.c.query("set client_encoding=utf8")
1432        self.c.query('set bytea_output=hex')
1433        self.c.query("set lc_monetary='C'")
1434
1435    def tearDown(self):
1436        self.c.close()
1437
1438    def testGetDecimalPoint(self):
1439        point = pg.get_decimal_point()
1440        # error if a parameter is passed
1441        self.assertRaises(TypeError, pg.get_decimal_point, point)
1442        self.assertIsInstance(point, str)
1443        self.assertEqual(point, '.')  # the default setting
1444        pg.set_decimal_point(',')
1445        try:
1446            r = pg.get_decimal_point()
1447        finally:
1448            pg.set_decimal_point(point)
1449        self.assertIsInstance(r, str)
1450        self.assertEqual(r, ',')
1451        pg.set_decimal_point("'")
1452        try:
1453            r = pg.get_decimal_point()
1454        finally:
1455            pg.set_decimal_point(point)
1456        self.assertIsInstance(r, str)
1457        self.assertEqual(r, "'")
1458        pg.set_decimal_point('')
1459        try:
1460            r = pg.get_decimal_point()
1461        finally:
1462            pg.set_decimal_point(point)
1463        self.assertIsNone(r)
1464        pg.set_decimal_point(None)
1465        try:
1466            r = pg.get_decimal_point()
1467        finally:
1468            pg.set_decimal_point(point)
1469        self.assertIsNone(r)
1470
1471    def testSetDecimalPoint(self):
1472        d = pg.Decimal
1473        point = pg.get_decimal_point()
1474        self.assertRaises(TypeError, pg.set_decimal_point)
1475        # error if decimal point is not a string
1476        self.assertRaises(TypeError, pg.set_decimal_point, 0)
1477        # error if more than one decimal point passed
1478        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
1479        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
1480        # error if decimal point is not a punctuation character
1481        self.assertRaises(TypeError, pg.set_decimal_point, '0')
1482        query = self.c.query
1483        # check that money values are interpreted as decimal values
1484        # only if decimal_point is set, and that the result is correct
1485        # only if it is set suitable for the current lc_monetary setting
1486        select_money = "select '34.25'::money"
1487        proper_money = d('34.25')
1488        bad_money = d('3425')
1489        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
1490        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
1491        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
1492        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
1493            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
1494        # first try with English localization (using the point)
1495        for lc in en_locales:
1496            try:
1497                query("set lc_monetary='%s'" % lc)
1498            except pg.DataError:
1499                pass
1500            else:
1501                break
1502        else:
1503            self.skipTest("cannot set English money locale")
1504        try:
1505            r = query(select_money)
1506        except pg.DataError:
1507            # this can happen if the currency signs cannot be
1508            # converted using the encoding of the test database
1509            self.skipTest("database does not support English money")
1510        pg.set_decimal_point(None)
1511        try:
1512            r = r.getresult()[0][0]
1513        finally:
1514            pg.set_decimal_point(point)
1515        self.assertIsInstance(r, str)
1516        self.assertIn(r, en_money)
1517        r = query(select_money)
1518        pg.set_decimal_point('')
1519        try:
1520            r = r.getresult()[0][0]
1521        finally:
1522            pg.set_decimal_point(point)
1523        self.assertIsInstance(r, str)
1524        self.assertIn(r, en_money)
1525        r = query(select_money)
1526        pg.set_decimal_point('.')
1527        try:
1528            r = r.getresult()[0][0]
1529        finally:
1530            pg.set_decimal_point(point)
1531        self.assertIsInstance(r, d)
1532        self.assertEqual(r, proper_money)
1533        r = query(select_money)
1534        pg.set_decimal_point(',')
1535        try:
1536            r = r.getresult()[0][0]
1537        finally:
1538            pg.set_decimal_point(point)
1539        self.assertIsInstance(r, d)
1540        self.assertEqual(r, bad_money)
1541        r = query(select_money)
1542        pg.set_decimal_point("'")
1543        try:
1544            r = r.getresult()[0][0]
1545        finally:
1546            pg.set_decimal_point(point)
1547        self.assertIsInstance(r, d)
1548        self.assertEqual(r, bad_money)
1549        # then try with German localization (using the comma)
1550        for lc in de_locales:
1551            try:
1552                query("set lc_monetary='%s'" % lc)
1553            except pg.DataError:
1554                pass
1555            else:
1556                break
1557        else:
1558            self.skipTest("cannot set German money locale")
1559        select_money = select_money.replace('.', ',')
1560        try:
1561            r = query(select_money)
1562        except pg.DataError:
1563            self.skipTest("database does not support English money")
1564        pg.set_decimal_point(None)
1565        try:
1566            r = r.getresult()[0][0]
1567        finally:
1568            pg.set_decimal_point(point)
1569        self.assertIsInstance(r, str)
1570        self.assertIn(r, de_money)
1571        r = query(select_money)
1572        pg.set_decimal_point('')
1573        try:
1574            r = r.getresult()[0][0]
1575        finally:
1576            pg.set_decimal_point(point)
1577        self.assertIsInstance(r, str)
1578        self.assertIn(r, de_money)
1579        r = query(select_money)
1580        pg.set_decimal_point(',')
1581        try:
1582            r = r.getresult()[0][0]
1583        finally:
1584            pg.set_decimal_point(point)
1585        self.assertIsInstance(r, d)
1586        self.assertEqual(r, proper_money)
1587        r = query(select_money)
1588        pg.set_decimal_point('.')
1589        try:
1590            r = r.getresult()[0][0]
1591        finally:
1592            pg.set_decimal_point(point)
1593        self.assertEqual(r, bad_money)
1594        r = query(select_money)
1595        pg.set_decimal_point("'")
1596        try:
1597            r = r.getresult()[0][0]
1598        finally:
1599            pg.set_decimal_point(point)
1600        self.assertEqual(r, bad_money)
1601
1602    def testGetDecimal(self):
1603        decimal_class = pg.get_decimal()
1604        # error if a parameter is passed
1605        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
1606        self.assertIs(decimal_class, pg.Decimal)  # the default setting
1607        pg.set_decimal(int)
1608        try:
1609            r = pg.get_decimal()
1610        finally:
1611            pg.set_decimal(decimal_class)
1612        self.assertIs(r, int)
1613        r = pg.get_decimal()
1614        self.assertIs(r, decimal_class)
1615
1616    def testSetDecimal(self):
1617        decimal_class = pg.get_decimal()
1618        # error if no parameter is passed
1619        self.assertRaises(TypeError, pg.set_decimal)
1620        query = self.c.query
1621        try:
1622            r = query("select 3425::numeric")
1623        except pg.DatabaseError:
1624            self.skipTest('database does not support numeric')
1625        r = r.getresult()[0][0]
1626        self.assertIsInstance(r, decimal_class)
1627        self.assertEqual(r, decimal_class('3425'))
1628        r = query("select 3425::numeric")
1629        pg.set_decimal(int)
1630        try:
1631            r = r.getresult()[0][0]
1632        finally:
1633            pg.set_decimal(decimal_class)
1634        self.assertNotIsInstance(r, decimal_class)
1635        self.assertIsInstance(r, int)
1636        self.assertEqual(r, int(3425))
1637
1638    def testGetBool(self):
1639        use_bool = pg.get_bool()
1640        # error if a parameter is passed
1641        self.assertRaises(TypeError, pg.get_bool, use_bool)
1642        self.assertIsInstance(use_bool, bool)
1643        self.assertIs(use_bool, True)  # the default setting
1644        pg.set_bool(False)
1645        try:
1646            r = pg.get_bool()
1647        finally:
1648            pg.set_bool(use_bool)
1649        self.assertIsInstance(r, bool)
1650        self.assertIs(r, False)
1651        pg.set_bool(True)
1652        try:
1653            r = pg.get_bool()
1654        finally:
1655            pg.set_bool(use_bool)
1656        self.assertIsInstance(r, bool)
1657        self.assertIs(r, True)
1658        pg.set_bool(0)
1659        try:
1660            r = pg.get_bool()
1661        finally:
1662            pg.set_bool(use_bool)
1663        self.assertIsInstance(r, bool)
1664        self.assertIs(r, False)
1665        pg.set_bool(1)
1666        try:
1667            r = pg.get_bool()
1668        finally:
1669            pg.set_bool(use_bool)
1670        self.assertIsInstance(r, bool)
1671        self.assertIs(r, True)
1672
1673    def testSetBool(self):
1674        use_bool = pg.get_bool()
1675        # error if no parameter is passed
1676        self.assertRaises(TypeError, pg.set_bool)
1677        query = self.c.query
1678        try:
1679            r = query("select true::bool")
1680        except pg.ProgrammingError:
1681            self.skipTest('database does not support bool')
1682        r = r.getresult()[0][0]
1683        self.assertIsInstance(r, bool)
1684        self.assertEqual(r, True)
1685        r = query("select true::bool")
1686        pg.set_bool(False)
1687        try:
1688            r = r.getresult()[0][0]
1689        finally:
1690            pg.set_bool(use_bool)
1691        self.assertIsInstance(r, str)
1692        self.assertIs(r, 't')
1693        r = query("select true::bool")
1694        pg.set_bool(True)
1695        try:
1696            r = r.getresult()[0][0]
1697        finally:
1698            pg.set_bool(use_bool)
1699        self.assertIsInstance(r, bool)
1700        self.assertIs(r, True)
1701
1702    def testGetByteEscaped(self):
1703        bytea_escaped = pg.get_bytea_escaped()
1704        # error if a parameter is passed
1705        self.assertRaises(TypeError, pg.get_bytea_escaped, bytea_escaped)
1706        self.assertIsInstance(bytea_escaped, bool)
1707        self.assertIs(bytea_escaped, False)  # the default setting
1708        pg.set_bytea_escaped(True)
1709        try:
1710            r = pg.get_bytea_escaped()
1711        finally:
1712            pg.set_bytea_escaped(bytea_escaped)
1713        self.assertIsInstance(r, bool)
1714        self.assertIs(r, True)
1715        pg.set_bytea_escaped(False)
1716        try:
1717            r = pg.get_bytea_escaped()
1718        finally:
1719            pg.set_bytea_escaped(bytea_escaped)
1720        self.assertIsInstance(r, bool)
1721        self.assertIs(r, False)
1722        pg.set_bytea_escaped(1)
1723        try:
1724            r = pg.get_bytea_escaped()
1725        finally:
1726            pg.set_bytea_escaped(bytea_escaped)
1727        self.assertIsInstance(r, bool)
1728        self.assertIs(r, True)
1729        pg.set_bytea_escaped(0)
1730        try:
1731            r = pg.get_bytea_escaped()
1732        finally:
1733            pg.set_bytea_escaped(bytea_escaped)
1734        self.assertIsInstance(r, bool)
1735        self.assertIs(r, False)
1736
1737    def testSetByteaEscaped(self):
1738        bytea_escaped = pg.get_bytea_escaped()
1739        # error if no parameter is passed
1740        self.assertRaises(TypeError, pg.set_bytea_escaped)
1741        query = self.c.query
1742        try:
1743            r = query("select 'data'::bytea")
1744        except pg.ProgrammingError:
1745            self.skipTest('database does not support bytea')
1746        r = r.getresult()[0][0]
1747        self.assertIsInstance(r, bytes)
1748        self.assertEqual(r, b'data')
1749        r = query("select 'data'::bytea")
1750        pg.set_bytea_escaped(True)
1751        try:
1752            r = r.getresult()[0][0]
1753        finally:
1754            pg.set_bytea_escaped(bytea_escaped)
1755        self.assertIsInstance(r, str)
1756        self.assertEqual(r, '\\x64617461')
1757        r = query("select 'data'::bytea")
1758        pg.set_bytea_escaped(False)
1759        try:
1760            r = r.getresult()[0][0]
1761        finally:
1762            pg.set_bytea_escaped(bytea_escaped)
1763        self.assertIsInstance(r, bytes)
1764        self.assertEqual(r, b'data')
1765
1766    def testGetNamedresult(self):
1767        namedresult = pg.get_namedresult()
1768        # error if a parameter is passed
1769        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
1770        self.assertIs(namedresult, pg._namedresult)  # the default setting
1771
1772    def testSetNamedresult(self):
1773        namedresult = pg.get_namedresult()
1774        self.assertTrue(callable(namedresult))
1775
1776        query = self.c.query
1777
1778        r = query("select 1 as x, 2 as y").namedresult()[0]
1779        self.assertIsInstance(r, tuple)
1780        self.assertEqual(r, (1, 2))
1781        self.assertIsNot(type(r), tuple)
1782        self.assertEqual(r._fields, ('x', 'y'))
1783        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
1784        self.assertEqual(r.__class__.__name__, 'Row')
1785
1786        def listresult(q):
1787            return [list(row) for row in q.getresult()]
1788
1789        pg.set_namedresult(listresult)
1790        try:
1791            r = pg.get_namedresult()
1792            self.assertIs(r, listresult)
1793            r = query("select 1 as x, 2 as y").namedresult()[0]
1794            self.assertIsInstance(r, list)
1795            self.assertEqual(r, [1, 2])
1796            self.assertIsNot(type(r), tuple)
1797            self.assertFalse(hasattr(r, '_fields'))
1798            self.assertNotEqual(r.__class__.__name__, 'Row')
1799        finally:
1800            pg.set_namedresult(namedresult)
1801
1802        r = pg.get_namedresult()
1803        self.assertIs(r, namedresult)
1804
1805
class TestStandaloneEscapeFunctions(unittest.TestCase):
    """Test pg escape functions.

    The libpq interface memorizes some parameters of the last opened
    connection that influence the result of these functions.  Therefore
    we need to open a connection with fixed parameters prior to testing
    in order to ensure that the tests always run under the same conditions.
    That's why these tests are included in this test module.
    """

    # becomes True once setUpClass has configured a connection
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Open a connection with the fixed settings the tests rely on."""
        db = connect()
        try:
            query = db.query
            query('set client_encoding=sql_ascii')
            query('set standard_conforming_strings=off')
            query('set bytea_output=escape')
        finally:
            # close explicitly instead of leaving it to garbage collection;
            # the connection was never kept beyond this method anyway
            db.close()
        cls.cls_set_up = True

    def testEscapeString(self):
        """Check escape_string() with bytes and unicode input."""
        self.assertTrue(self.cls_set_up)
        f = pg.escape_string
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # u'\u00e4' is the German a-umlaut, written as an escape so that
        # the test data cannot be garbled by re-encoding the source file
        r = f(u"das is' k\u00e4se".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, u"das is'' k\u00e4se".encode('utf-8'))
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        # with standard_conforming_strings=off, backslashes are doubled too
        r = f(r"It's bad to have a \ inside.")
        self.assertIsInstance(r, str)
        self.assertEqual(r, r"It''s bad to have a \\ inside.")

    def testEscapeBytea(self):
        """Check escape_bytea() with bytes and unicode input."""
        self.assertTrue(self.cls_set_up)
        f = pg.escape_bytea
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # the UTF-8 bytes of u'\u00e4' (0xc3 0xa4) become the octal
        # escapes 303 and 244 below
        r = f(u"das is' k\u00e4se".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(b'O\x00ps\xff!')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'O\\\\000ps\\\\377!')
1861
1862
if '__main__' == __name__:
    # executed as a script: run every test case defined in this module
    unittest.main()
Note: See TracBrowser for help on using the repository browser.