source: trunk/tests/test_classic_connection.py @ 798

Last change on this file since 798 was 798, checked in by cito, 3 years ago

Port type cache and typecasting from pgdb to pg

So far, the typecasting in the classic module has only been done by
the C extension module and was not extensible through typecasting
functions in Python. This has now been made extensible by adding
a cast hook to the C extension module which has been hooked up to
a new type cache object that holds information on the types and the
associated typecast functions. All of this works very similarly to the
pgdb module now, except that the basic types are still handled by
the C extension module and the Python typecast functions are only
called via the hook for types which are not supported internally.

Also added tests and a chapter on the type cache in the documentation,
and cleaned up the error messages in the C extension module.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 62.1 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17import threading
18import time
19import os
20
21import pg  # the module under test
22
23from decimal import Decimal
24
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# These tests should be run with various PostgreSQL versions and databases
# created with different encodings and locales.  Particularly, make sure the
# tests are running against databases created with both SQL_ASCII and UTF8.
dbname = 'unittest'
dbhost = None
dbport = 5432

try:
    # The relative import works when the tests are run as part of a package.
    # Outside of a package, Python 2 raises ValueError, Python 3 versions
    # before 3.6 raise SystemError, and newer versions raise ImportError.
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError, SystemError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local settings, keep the defaults above

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# True if the native str type holds unicode text (Python 3).
unicode_strings = str is not bytes

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
60
61
def connect():
    """Open a pg connection to the test database.

    The session is configured so that only messages of level
    warning or higher are sent to the client.
    """
    conn = pg.connect(dbname, dbhost, dbport)
    conn.query("set client_min_messages=warning")
    return conn
67
68
class TestCanConnect(unittest.TestCase):
    """Check that a connection to the test database can be opened and closed."""

    def testCanConnect(self):
        try:
            conn = connect()
        except pg.Error as err:
            self.fail('Cannot connect to database %s:\n%s' % (dbname, err))
        else:
            # only reached when connect() succeeded
            try:
                conn.close()
            except pg.Error:
                self.fail('Cannot close the database connection')
81
82
class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods."""

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        # Closing an already closed connection raises pg.InternalError,
        # e.g. when a test closed the connection and failed before
        # reconnecting; that is harmless during teardown.
        try:
            self.connection.close()
        except pg.InternalError:
            pass

    def is_method(self, attribute):
        """Check if given attribute on the connection is a method."""
        if do_not_ask_for_host and attribute == 'host':
            # accessing the host attribute calls PQhost() (see module flags)
            return False
        return callable(getattr(self.connection, attribute))

    def testClassName(self):
        self.assertEqual(self.connection.__class__.__name__, 'Connection')

    def testModuleName(self):
        self.assertEqual(self.connection.__class__.__module__, 'pg')

    def testStr(self):
        r = str(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testRepr(self):
        r = repr(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testAllConnectAttributes(self):
        attributes = '''db error host options port
            protocol_version server_version status user'''.split()
        connection_attributes = [a for a in dir(self.connection)
            if not a.startswith('__') and not self.is_method(a)]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        methods = '''cancel close endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_cast_hook get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter putline query reset
            set_cast_hook set_notice_receiver source transaction'''.split()
        connection_methods = [a for a in dir(self.connection)
            if not a.startswith('__') and self.is_method(a)]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, dbhost or def_host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        # Since PostgreSQL 10 the version number has six digits
        # (major * 10000 + minor, e.g. 100001 for 10.1), so the upper
        # bound only rejects values that are clearly out of range.
        self.assertTrue(70400 <= server_version < 1000000)

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testMethodEndcopy(self):
        # endcopy() outside of a copy operation may raise IOError
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # reconnect so that tearDown has an open connection to close
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query
        # check that client encoding gets reset
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        self.connection.query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            # runs in a separate thread; the cancelled query is expected
            # to raise pg.ProgrammingError, which is collected here
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.ProgrammingError as error:
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        while 1:  # make sure the query is really running
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)
257
258
class TestSimpleQueries(unittest.TestCase):
    """Test simple queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # run registered cleanups first, since they may still need
        # the connection (e.g. to drop a test table)
        self.doCleanups()
        self.c.close()

    def testClassName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__name__, 'Query')

    def testModuleName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__module__, 'pg')

    def testStr(self):
        q = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        r = self.c.query(q)
        self.assertEqual(str(r),
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)')

    def testRepr(self):
        r = repr(self.c.query("select 1"))
        self.assertTrue(r.startswith('<pg.Query object'), r)

    def testSelect0(self):
        q = "select 0"
        self.c.query(q)

    def testSelect0Semicolon(self):
        q = "select 0;"
        self.c.query(q)

    def testSelectDotSemicolon(self):
        q = "select .;"
        self.assertRaises(pg.ProgrammingError, self.c.query, q)

    def testGetresult(self):
        q = "select 0"
        result = [(0,)]
        r = self.c.query(q).getresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, tuple)
        self.assertIsInstance(v[0], int)
        self.assertEqual(r, result)

    def testGetresultLong(self):
        q = "select 9876543210"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testGetresultDecimal(self):
        q = "select 98765432109876543210"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testGetresultString(self):
        result = 'Hello, world!'
        q = "select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresult(self):
        q = "select 0 as alias0"
        result = [{'alias0': 0}]
        r = self.c.query(q).dictresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, dict)
        self.assertIsInstance(v['alias0'], int)
        self.assertEqual(r, result)

    def testDictresultLong(self):
        q = "select 9876543210 as longjohnsilver"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testDictresultDecimal(self):
        q = "select 98765432109876543210 as longjohnsilver"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testDictresultString(self):
        result = 'Hello, world!'
        q = "select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testNamedresult(self):
        q = "select 0 as alias0"
        result = [(0,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('alias0',))
        self.assertEqual(v.alias0, 0)

    def testGet3Cols(self):
        q = "select 1,2,3"
        result = [(1, 2, 3)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [dict(a=1, b=2, c=3)]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [(1, 2, 3)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('a', 'b', 'c'))
        self.assertEqual(v.b, 2)

    def testGet3Rows(self):
        q = "select 3 union select 1 union select 2 order by 1"
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        for v in r:
            self.assertEqual(v._fields, ('alias3',))

    def testDictresultNames(self):
        # unquoted aliases are folded to lower case, quoted ones are kept
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [{'mixedcasealias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        result = [{'MixedCaseAlias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testNamedresultNames(self):
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [('MixedCase',)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('mixedcasealias',))
        self.assertEqual(v.mixedcasealias, 'MixedCase')
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('MixedCaseAlias',))
        self.assertEqual(v.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = 100
        num_rows = 100
        q = "select " + ','.join(map(str, range(num_cols)))
        q = ' union all '.join((q,) * num_rows)
        r = self.c.query(q).getresult()
        result = [tuple(range(num_cols))] * num_rows
        self.assertEqual(r, result)

    def testListfields(self):
        q = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        r = self.c.query(q).listfields()
        result = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(r, result)

    def testFieldname(self):
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldname(2)
        self.assertEqual(r, 'x')
        r = self.c.query(q).fieldname(3)
        self.assertEqual(r, 'y')

    def testFieldnum(self):
        q = "select 1 as x"
        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        r = self.c.query(q).fieldnum('y')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3)

    def testNtuples(self):
        q = "select 1 where false"
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 6)

    def testQuery(self):
        query = self.c.query
        query("drop table if exists test_table")
        q = "create table test_table (n integer) with oids"
        r = query(q)
        # Register the cleanup only once the table exists, so that a failing
        # "create table" is not followed by a misleading failing "drop table".
        self.addCleanup(query, "drop table test_table")
        self.assertIsNone(r)
        # inserting a single row returns the oid of the new row
        q = "insert into test_table values (1)"
        r = query(q)
        self.assertIsInstance(r, int)
        q = "insert into test_table select 2"
        r = query(q)
        self.assertIsInstance(r, int)
        oid = r
        q = "select oid from test_table where n=2"
        r = query(q).getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertIsInstance(r, int)
        self.assertEqual(r, oid)
        # statements affecting several rows return the row count as a string
        q = "insert into test_table select 3 union select 4 union select 5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '3')
        q = "update test_table set n=4 where n<5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '4')
        q = "delete from test_table"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '5')
540
541
class TestUnicodeQueries(unittest.TestCase):
    """Test unicode strings as queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')

    def tearDown(self):
        self.c.close()

    def testGetresultAscii(self):
        result = u'Hello, world!'
        q = u"select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultAscii(self):
        result = u'Hello, world!'
        q = u"select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('utf8')
        # pass the query as unicode
        try:
            v = self.c.query(q).getresult()[0][0]
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        # pass the query as bytes
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('utf8')
        try:
            v = self.c.query(q).dictresult()[0]['greeting']
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    # This test was previously also named testDictresultLatin1 and was
    # therefore shadowed by the dictresult test below and never executed.
    def testGetresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except pg.ProgrammingError:
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except pg.ProgrammingError:
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin9")
        result = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin9")
        result = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
        q = u"select '%s' as menu" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
701
702
703class TestParamQueries(unittest.TestCase):
704    """Test queries with parameters via a basic pg connection."""
705
706    def setUp(self):
707        self.c = connect()
708        self.c.query('set client_encoding=utf8')
709
710    def tearDown(self):
711        self.c.close()
712
713    def testQueryWithNoneParam(self):
714        self.assertRaises(TypeError, self.c.query, "select $1", None)
715        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
716        self.assertEqual(self.c.query("select $1::integer", (None,)
717            ).getresult(), [(None,)])
718        self.assertEqual(self.c.query("select $1::text", [None]
719            ).getresult(), [(None,)])
720        self.assertEqual(self.c.query("select $1::text", [[None]]
721            ).getresult(), [(None,)])
722
723    def testQueryWithBoolParams(self, use_bool=None):
724        query = self.c.query
725        if use_bool is not None:
726            use_bool_default = pg.get_bool()
727            pg.set_bool(use_bool)
728        try:
729            v_false, v_true = (False, True) if use_bool else 'ft'
730            r_false, r_true = [(v_false,)], [(v_true,)]
731            self.assertEqual(query("select false").getresult(), r_false)
732            self.assertEqual(query("select true").getresult(), r_true)
733            q = "select $1::bool"
734            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
735            self.assertEqual(query(q, ('f',)).getresult(), r_false)
736            self.assertEqual(query(q, ('t',)).getresult(), r_true)
737            self.assertEqual(query(q, ('false',)).getresult(), r_false)
738            self.assertEqual(query(q, ('true',)).getresult(), r_true)
739            self.assertEqual(query(q, ('n',)).getresult(), r_false)
740            self.assertEqual(query(q, ('y',)).getresult(), r_true)
741            self.assertEqual(query(q, (0,)).getresult(), r_false)
742            self.assertEqual(query(q, (1,)).getresult(), r_true)
743            self.assertEqual(query(q, (False,)).getresult(), r_false)
744            self.assertEqual(query(q, (True,)).getresult(), r_true)
745        finally:
746            if use_bool is not None:
747                pg.set_bool(use_bool_default)
748
749    def testQueryWithBoolParamsAndUseBool(self):
750        self.testQueryWithBoolParams(use_bool=True)
751
752    def testQueryWithIntParams(self):
753        query = self.c.query
754        self.assertEqual(query("select 1+1").getresult(), [(2,)])
755        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
756        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
757        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
758        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
759        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
760            [(Decimal('2'),)])
761        self.assertEqual(query("select 1, $1::integer", (2,)
762            ).getresult(), [(1, 2)])
763        self.assertEqual(query("select 1 union select $1::integer", (2,)
764            ).getresult(), [(1,), (2,)])
765        self.assertEqual(query("select $1::integer+$2", (1, 2)
766            ).getresult(), [(3,)])
767        self.assertEqual(query("select $1::integer+$2", [1, 2]
768            ).getresult(), [(3,)])
769        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
770            ).getresult(), [(15,)])
771
772    def testQueryWithStrParams(self):
773        query = self.c.query
774        self.assertEqual(query("select $1||', world!'", ('Hello',)
775            ).getresult(), [('Hello, world!',)])
776        self.assertEqual(query("select $1||', world!'", ['Hello']
777            ).getresult(), [('Hello, world!',)])
778        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
779            ).getresult(), [('Hello, world!',)])
780        self.assertEqual(query("select $1::text", ('Hello, world!',)
781            ).getresult(), [('Hello, world!',)])
782        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
783            ).getresult(), [('Hello', 'world')])
784        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
785            ).getresult(), [('Hello', 'world')])
786        self.assertEqual(query("select $1::text union select $2::text",
787            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
788        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
789            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
790
791    def testQueryWithUnicodeParams(self):
792        query = self.c.query
793        try:
794            query('set client_encoding=utf8')
795            query("select 'wörld'").getresult()[0][0] == 'wörld'
796        except pg.ProgrammingError:
797            self.skipTest("database does not support utf8")
798        self.assertEqual(query("select $1||', '||$2||'!'",
799            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
800
801    def testQueryWithUnicodeParamsLatin1(self):
802        query = self.c.query
803        try:
804            query('set client_encoding=latin1')
805            query("select 'wörld'").getresult()[0][0] == 'wörld'
806        except pg.ProgrammingError:
807            self.skipTest("database does not support latin1")
808        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
809        if unicode_strings:
810            self.assertEqual(r, [('Hello, wörld!',)])
811        else:
812            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
813        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
814            ('Hello', u'ЌОр'))
815        query('set client_encoding=iso_8859_1')
816        r = query("select $1||', '||$2||'!'",
817            ('Hello', u'wörld')).getresult()
818        if unicode_strings:
819            self.assertEqual(r, [('Hello, wörld!',)])
820        else:
821            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
822        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
823            ('Hello', u'ЌОр'))
824        query('set client_encoding=sql_ascii')
825        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
826            ('Hello', u'wörld'))
827
828    def testQueryWithUnicodeParamsCyrillic(self):
829        query = self.c.query
830        try:
831            query('set client_encoding=iso_8859_5')
832            query("select 'ЌОр'").getresult()[0][0] == 'ЌОр'
833        except pg.ProgrammingError:
834            self.skipTest("database does not support cyrillic")
835        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
836            ('Hello', u'wörld'))
837        r = query("select $1||', '||$2||'!'",
838            ('Hello', u'ЌОр')).getresult()
839        if unicode_strings:
840            self.assertEqual(r, [('Hello, ЌОр!',)])
841        else:
842            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
843        query('set client_encoding=sql_ascii')
844        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
845            ('Hello', u'ЌОр!'))
846
847    def testQueryWithMixedParams(self):
848        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
849            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
850        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
851            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
852
853    def testQueryWithDuplicateParams(self):
854        self.assertRaises(pg.ProgrammingError,
855            self.c.query, "select $1+$1", (1,))
856        self.assertRaises(pg.ProgrammingError,
857            self.c.query, "select $1+$1", (1, 2))
858
859    def testQueryWithZeroParams(self):
860        self.assertEqual(self.c.query("select 1+1", []
861            ).getresult(), [(2,)])
862
863    def testQueryWithGarbage(self):
864        garbage = r"'\{}+()-#[]oo324"
865        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
866            ).dictresult(), [{'garbage': garbage}])
867
868
869class TestQueryResultTypes(unittest.TestCase):
870    """Test proper result types via a basic pg connection."""
871
872    def setUp(self):
873        self.c = connect()
874        self.c.query('set client_encoding=utf8')
875        self.c.query("set datestyle='ISO,YMD'")
876
877    def tearDown(self):
878        self.c.close()
879
880    def assert_proper_cast(self, value, pgtype, pytype):
881        q = 'select $1::%s' % (pgtype,)
882        r = self.c.query(q, (value,)).getresult()[0][0]
883        self.assertIsInstance(r, pytype)
884        if isinstance(value, (bytes, str)):
885            if not value or '{':
886                value = '"%s"' % value
887        value = '{%s}' % value
888        r = self.c.query(q + '[]', (value,)).getresult()[0][0]
889        self.assertIsInstance(r, list)
890        self.assertEqual(len(r), 1)
891        self.assertIsInstance(r[0], pytype)
892
893    def testInt(self):
894        self.assert_proper_cast(0, 'int', int)
895        self.assert_proper_cast(0, 'smallint', int)
896        self.assert_proper_cast(0, 'oid', int)
897        self.assert_proper_cast(0, 'cid', int)
898        self.assert_proper_cast(0, 'xid', int)
899
900    def testLong(self):
901        self.assert_proper_cast(0, 'bigint', long)
902
903    def testFloat(self):
904        self.assert_proper_cast(0, 'float', float)
905        self.assert_proper_cast(0, 'real', float)
906        self.assert_proper_cast(0, 'double', float)
907        self.assert_proper_cast(0, 'double precision', float)
908        self.assert_proper_cast('infinity', 'float', float)
909
910    def testFloat(self):
911        decimal = pg.get_decimal()
912        self.assert_proper_cast(decimal(0), 'numeric', decimal)
913        self.assert_proper_cast(decimal(0), 'decimal', decimal)
914
915    def testMoney(self):
916        decimal = pg.get_decimal()
917        self.assert_proper_cast(decimal('0'), 'money', decimal)
918
919    def testBool(self):
920        bool_type = bool if pg.get_bool() else str
921        self.assert_proper_cast('f', 'bool', bool_type)
922
923    def testDate(self):
924        self.assert_proper_cast('1956-01-31', 'date', str)
925        self.assert_proper_cast('0', 'interval', str)
926        self.assert_proper_cast('08:42', 'time', str)
927        self.assert_proper_cast('08:42', 'timetz', str)
928        self.assert_proper_cast('1956-01-31 08:42', 'timestamp', str)
929        self.assert_proper_cast('1956-01-31 08:42', 'timestamptz', str)
930
931    def testText(self):
932        self.assert_proper_cast('', 'text', str)
933        self.assert_proper_cast('', 'char', str)
934        self.assert_proper_cast('', 'bpchar', str)
935        self.assert_proper_cast('', 'varchar', str)
936
937    def testBytea(self):
938        self.assert_proper_cast('', 'bytea', bytes)
939
940    def testJson(self):
941        self.assert_proper_cast('{}', 'json', dict)
942
943
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        c = connect()
        c.query("drop table if exists test cascade")
        c.query("create table test ("
            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
            "d numeric, f4 real, f8 double precision, m money,"
            "c char(1), v4 varchar(4), c4 char(4), t text)")
        # Check whether the test database uses SQL_ASCII - this means
        # that it does not consider encoding when calculating lengths.
        c.query("set client_encoding=utf8")
        cls.has_encoding = c.query(
            "select length('À') - length('a')").getresult()[0][0] == 0
        c.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        c = connect()
        c.query("drop table test cascade")
        c.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    # sample rows, one value per column of the test table
    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def db_len(cls, s, encoding):
        """Return the length of s as the test database counts it.

        With an encoding the characters are counted, otherwise (SQL_ASCII)
        the bytes of s encoded with the given encoding are counted.
        """
        if cls.has_encoding:
            s = s if isinstance(s, unicode) else s.decode(encoding)
        else:
            s = s.encode(encoding) if isinstance(s, unicode) else s
        return len(s)

    @classmethod
    def _unicode_row(cls, text):
        """Return a row of unicode sample values with text as last column."""
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if cls.has_encoding else u'$'
        return (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0', c, u'bÀd', u'bÀd', text)

    @staticmethod
    def _encode_row(row, encoding):
        """Return row with all unicode values encoded with encoding."""
        return tuple(s.encode(encoding) if isinstance(s, unicode) else s
            for s in row)

    def get_back(self, encoding='utf-8'):
        """Fetch all rows of the test table ordered by the first column.

        The Python type of every value is checked and booleans, numerics,
        money and char(4) values are converted back to the form used in
        the inserted data, so the result can be compared with it.
        """
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], str)
                row[3] = {'f': False, 't': True}.get(row[3])
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision
                self.assertIsInstance(row[8], float)
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(self.db_len(row[10], encoding), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(self.db_len(row[11], encoding), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(self.db_len(row[12], encoding), 4)
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            row = tuple(row)
            data.append(row)
        return data

    def testInserttable1Row(self):
        data = self.data[2:3]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        data = self.data
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMultipleRows(self):
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        data = [(None,) * 14] * 100
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableByteValues(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [self._encode_row(row_unicode, 'utf-8')] * 2
        self.c.inserttable('test', data)
        if unicode_strings:
            data = [row_unicode] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeUtf8(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'utf-8')] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeLatin1(self):
        try:
            self.c.query("set client_encoding=latin1")
            self.c.query("select '¥'")
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        # cannot encode € sign with latin1 encoding
        self.assertRaises(UnicodeEncodeError,
            self.c.inserttable, 'test', [row_unicode])
        # replace it with ¥ which exists in latin1; this must be a single
        # character so that it still fits into the char(1) column
        row_unicode = tuple(s.replace(u'€', u'¥')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'latin1')] * 2
        self.assertEqual(self.get_back('latin1'), data)

    def testInserttableUnicodeLatin9(self):
        try:
            self.c.query("set client_encoding=latin9")
            self.c.query("select '€'")
        except pg.ProgrammingError:
            self.skipTest("database does not support latin9")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'latin9')] * 2
        self.assertEqual(self.get_back('latin9'), data)

    def testInserttableNoEncoding(self):
        self.c.query("set client_encoding=sql_ascii")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        # cannot encode non-ascii unicode without a specific encoding
        self.assertRaises(UnicodeEncodeError,
            self.c.inserttable, 'test', [row_unicode])
1177
1178
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        conn = connect()
        for sql in ("drop table if exists test cascade",
                "create table test (i int, v varchar(16))"):
            conn.query(sql)
        conn.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        conn = connect()
        conn.query("drop table test cascade")
        conn.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    def _end_copy_quietly(self):
        """Call endcopy() on the connection, ignoring any IOError."""
        try:
            self.c.endcopy()
        except IOError:
            pass

    def testPutline(self):
        rows = list(enumerate("apple pear plum cherry banana".split()))
        self.c.query("copy test from stdin")
        try:
            for row in rows:
                self.c.putline("%d\t%s\n" % row)
            self.c.putline("\\.\n")
        finally:
            self.c.endcopy()
        self.assertEqual(self.c.query("select * from test").getresult(), rows)

    def testPutlineBytesAndUnicode(self):
        self.c.query("copy test from stdin")
        try:
            for line in (u"47\tkÀse\n".encode('utf8'), "35\twÃŒrstel\n",
                    b"\\.\n"):
                self.c.putline(line)
        finally:
            self.c.endcopy()
        result = self.c.query("select * from test").getresult()
        self.assertEqual(result, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        rows = list(enumerate("apple banana pear plum strawberry".split()))
        self.c.inserttable('test', rows)
        self.c.query("copy test to stdout")
        try:
            for row in rows:
                self.assertEqual(self.c.getline(), '%d\t%s' % row)
            # the end-of-data marker, then nothing more
            self.assertEqual(self.c.getline(), '\\.')
            self.assertIsNone(self.c.getline())
        finally:
            self._end_copy_quietly()

    def testGetlineBytesAndUnicode(self):
        self.c.inserttable('test',
            [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')])
        self.c.query("copy test to stdout")
        try:
            for expected in ('54\tkÀse', '73\twÃŒrstel'):
                line = self.c.getline()
                self.assertIsInstance(line, str)
                self.assertEqual(line, expected)
            self.assertEqual(self.c.getline(), '\\.')
            self.assertIsNone(self.c.getline())
        finally:
            self._end_copy_quietly()

    def testParameterChecks(self):
        conn = self.c
        for method, args in ((conn.putline, ()),
                (conn.getline, ('invalid',)),
                (conn.endcopy, ('invalid',))):
            self.assertRaises(TypeError, method, *args)
1281
1282
1283class TestNotificatons(unittest.TestCase):
1284    """Test notification support."""
1285
1286    def setUp(self):
1287        self.c = connect()
1288
1289    def tearDown(self):
1290        self.doCleanups()
1291        self.c.close()
1292
1293    def testGetNotify(self):
1294        getnotify = self.c.getnotify
1295        query = self.c.query
1296        self.assertIsNone(getnotify())
1297        query('listen test_notify')
1298        try:
1299            self.assertIsNone(self.c.getnotify())
1300            query("notify test_notify")
1301            r = getnotify()
1302            self.assertIsInstance(r, tuple)
1303            self.assertEqual(len(r), 3)
1304            self.assertIsInstance(r[0], str)
1305            self.assertIsInstance(r[1], int)
1306            self.assertIsInstance(r[2], str)
1307            self.assertEqual(r[0], 'test_notify')
1308            self.assertEqual(r[2], '')
1309            self.assertIsNone(self.c.getnotify())
1310            query("notify test_notify, 'test_payload'")
1311            r = getnotify()
1312            self.assertTrue(isinstance(r, tuple))
1313            self.assertEqual(len(r), 3)
1314            self.assertIsInstance(r[0], str)
1315            self.assertIsInstance(r[1], int)
1316            self.assertIsInstance(r[2], str)
1317            self.assertEqual(r[0], 'test_notify')
1318            self.assertEqual(r[2], 'test_payload')
1319            self.assertIsNone(getnotify())
1320        finally:
1321            query('unlisten test_notify')
1322
1323    def testGetNoticeReceiver(self):
1324        self.assertIsNone(self.c.get_notice_receiver())
1325
1326    def testSetNoticeReceiver(self):
1327        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
1328        self.assertRaises(TypeError, self.c.set_notice_receiver, 'invalid')
1329        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))
1330        self.assertIsNone(self.c.set_notice_receiver(None))
1331
1332    def testSetAndGetNoticeReceiver(self):
1333        r = lambda notice: None
1334        self.assertIsNone(self.c.set_notice_receiver(r))
1335        self.assertIs(self.c.get_notice_receiver(), r)
1336        self.assertIsNone(self.c.set_notice_receiver(None))
1337        self.assertIsNone(self.c.get_notice_receiver())
1338
1339    def testNoticeReceiver(self):
1340        self.addCleanup(self.c.query, 'drop function bilbo_notice();')
1341        self.c.query('''create function bilbo_notice() returns void AS $$
1342            begin
1343                raise warning 'Bilbo was here!';
1344            end;
1345            $$ language plpgsql''')
1346        received = {}
1347
1348        def notice_receiver(notice):
1349            for attr in dir(notice):
1350                if attr.startswith('__'):
1351                    continue
1352                value = getattr(notice, attr)
1353                if isinstance(value, str):
1354                    value = value.replace('WARNUNG', 'WARNING')
1355                received[attr] = value
1356
1357        self.c.set_notice_receiver(notice_receiver)
1358        self.c.query('select bilbo_notice()')
1359        self.assertEqual(received, dict(
1360            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
1361            severity='WARNING', primary='Bilbo was here!',
1362            detail=None, hint=None))
1363
1364
1365class TestConfigFunctions(unittest.TestCase):
1366    """Test the functions for changing default settings.
1367
1368    To test the effect of most of these functions, we need a database
1369    connection.  That's why they are covered in this test module.
1370    """
1371
1372    def setUp(self):
1373        self.c = connect()
1374        self.c.query("set client_encoding=utf8")
1375        self.c.query("set lc_monetary='C'")
1376
1377    def tearDown(self):
1378        self.c.close()
1379
1380    def testGetDecimalPoint(self):
1381        point = pg.get_decimal_point()
1382        # error if a parameter is passed
1383        self.assertRaises(TypeError, pg.get_decimal_point, point)
1384        self.assertIsInstance(point, str)
1385        self.assertEqual(point, '.')  # the default setting
1386        pg.set_decimal_point(',')
1387        try:
1388            r = pg.get_decimal_point()
1389        finally:
1390            pg.set_decimal_point(point)
1391        self.assertIsInstance(r, str)
1392        self.assertEqual(r, ',')
1393        pg.set_decimal_point("'")
1394        try:
1395            r = pg.get_decimal_point()
1396        finally:
1397            pg.set_decimal_point(point)
1398        self.assertIsInstance(r, str)
1399        self.assertEqual(r, "'")
1400        pg.set_decimal_point('')
1401        try:
1402            r = pg.get_decimal_point()
1403        finally:
1404            pg.set_decimal_point(point)
1405        self.assertIsNone(r)
1406        pg.set_decimal_point(None)
1407        try:
1408            r = pg.get_decimal_point()
1409        finally:
1410            pg.set_decimal_point(point)
1411        self.assertIsNone(r)
1412
1413    def testSetDecimalPoint(self):
1414        d = pg.Decimal
1415        point = pg.get_decimal_point()
1416        self.assertRaises(TypeError, pg.set_decimal_point)
1417        # error if decimal point is not a string
1418        self.assertRaises(TypeError, pg.set_decimal_point, 0)
1419        # error if more than one decimal point passed
1420        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
1421        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
1422        # error if decimal point is not a punctuation character
1423        self.assertRaises(TypeError, pg.set_decimal_point, '0')
1424        query = self.c.query
1425        # check that money values are interpreted as decimal values
1426        # only if decimal_point is set, and that the result is correct
1427        # only if it is set suitable for the current lc_monetary setting
1428        select_money = "select '34.25'::money"
1429        proper_money = d('34.25')
1430        bad_money = d('3425')
1431        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
1432        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
1433        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
1434        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
1435            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
1436        # first try with English localization (using the point)
1437        for lc in en_locales:
1438            try:
1439                query("set lc_monetary='%s'" % lc)
1440            except pg.ProgrammingError:
1441                pass
1442            else:
1443                break
1444        else:
1445            self.skipTest("cannot set English money locale")
1446        try:
1447            r = query(select_money)
1448        except pg.ProgrammingError:
1449            # this can happen if the currency signs cannot be
1450            # converted using the encoding of the test database
1451            self.skipTest("database does not support English money")
1452        pg.set_decimal_point(None)
1453        try:
1454            r = r.getresult()[0][0]
1455        finally:
1456            pg.set_decimal_point(point)
1457        self.assertIsInstance(r, str)
1458        self.assertIn(r, en_money)
1459        r = query(select_money)
1460        pg.set_decimal_point('')
1461        try:
1462            r = r.getresult()[0][0]
1463        finally:
1464            pg.set_decimal_point(point)
1465        self.assertIsInstance(r, str)
1466        self.assertIn(r, en_money)
1467        r = query(select_money)
1468        pg.set_decimal_point('.')
1469        try:
1470            r = r.getresult()[0][0]
1471        finally:
1472            pg.set_decimal_point(point)
1473        self.assertIsInstance(r, d)
1474        self.assertEqual(r, proper_money)
1475        r = query(select_money)
1476        pg.set_decimal_point(',')
1477        try:
1478            r = r.getresult()[0][0]
1479        finally:
1480            pg.set_decimal_point(point)
1481        self.assertIsInstance(r, d)
1482        self.assertEqual(r, bad_money)
1483        r = query(select_money)
1484        pg.set_decimal_point("'")
1485        try:
1486            r = r.getresult()[0][0]
1487        finally:
1488            pg.set_decimal_point(point)
1489        self.assertIsInstance(r, d)
1490        self.assertEqual(r, bad_money)
1491        # then try with German localization (using the comma)
1492        for lc in de_locales:
1493            try:
1494                query("set lc_monetary='%s'" % lc)
1495            except pg.ProgrammingError:
1496                pass
1497            else:
1498                break
1499        else:
1500            self.skipTest("cannot set German money locale")
1501        select_money = select_money.replace('.', ',')
1502        try:
1503            r = query(select_money)
1504        except pg.ProgrammingError:
1505            self.skipTest("database does not support English money")
1506        pg.set_decimal_point(None)
1507        try:
1508            r = r.getresult()[0][0]
1509        finally:
1510            pg.set_decimal_point(point)
1511        self.assertIsInstance(r, str)
1512        self.assertIn(r, de_money)
1513        r = query(select_money)
1514        pg.set_decimal_point('')
1515        try:
1516            r = r.getresult()[0][0]
1517        finally:
1518            pg.set_decimal_point(point)
1519        self.assertIsInstance(r, str)
1520        self.assertIn(r, de_money)
1521        r = query(select_money)
1522        pg.set_decimal_point(',')
1523        try:
1524            r = r.getresult()[0][0]
1525        finally:
1526            pg.set_decimal_point(point)
1527        self.assertIsInstance(r, d)
1528        self.assertEqual(r, proper_money)
1529        r = query(select_money)
1530        pg.set_decimal_point('.')
1531        try:
1532            r = r.getresult()[0][0]
1533        finally:
1534            pg.set_decimal_point(point)
1535        self.assertEqual(r, bad_money)
1536        r = query(select_money)
1537        pg.set_decimal_point("'")
1538        try:
1539            r = r.getresult()[0][0]
1540        finally:
1541            pg.set_decimal_point(point)
1542        self.assertEqual(r, bad_money)
1543
1544    def testGetDecimal(self):
1545        decimal_class = pg.get_decimal()
1546        # error if a parameter is passed
1547        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
1548        self.assertIs(decimal_class, pg.Decimal)  # the default setting
1549        pg.set_decimal(int)
1550        try:
1551            r = pg.get_decimal()
1552        finally:
1553            pg.set_decimal(decimal_class)
1554        self.assertIs(r, int)
1555        r = pg.get_decimal()
1556        self.assertIs(r, decimal_class)
1557
1558    def testSetDecimal(self):
1559        decimal_class = pg.get_decimal()
1560        # error if no parameter is passed
1561        self.assertRaises(TypeError, pg.set_decimal)
1562        query = self.c.query
1563        try:
1564            r = query("select 3425::numeric")
1565        except pg.ProgrammingError:
1566            self.skipTest('database does not support numeric')
1567        r = r.getresult()[0][0]
1568        self.assertIsInstance(r, decimal_class)
1569        self.assertEqual(r, decimal_class('3425'))
1570        r = query("select 3425::numeric")
1571        pg.set_decimal(int)
1572        try:
1573            r = r.getresult()[0][0]
1574        finally:
1575            pg.set_decimal(decimal_class)
1576        self.assertNotIsInstance(r, decimal_class)
1577        self.assertIsInstance(r, int)
1578        self.assertEqual(r, int(3425))
1579
1580    def testGetBool(self):
1581        use_bool = pg.get_bool()
1582        # error if a parameter is passed
1583        self.assertRaises(TypeError, pg.get_bool, use_bool)
1584        self.assertIsInstance(use_bool, bool)
1585        self.assertIs(use_bool, False)  # the default setting
1586        pg.set_bool(True)
1587        try:
1588            r = pg.get_bool()
1589        finally:
1590            pg.set_bool(use_bool)
1591        self.assertIsInstance(r, bool)
1592        self.assertIs(r, True)
1593        pg.set_bool(False)
1594        try:
1595            r = pg.get_bool()
1596        finally:
1597            pg.set_bool(use_bool)
1598        self.assertIsInstance(r, bool)
1599        self.assertIs(r, False)
1600        pg.set_bool(1)
1601        try:
1602            r = pg.get_bool()
1603        finally:
1604            pg.set_bool(use_bool)
1605        self.assertIsInstance(r, bool)
1606        self.assertIs(r, True)
1607        pg.set_bool(0)
1608        try:
1609            r = pg.get_bool()
1610        finally:
1611            pg.set_bool(use_bool)
1612        self.assertIsInstance(r, bool)
1613        self.assertIs(r, False)
1614
1615    def testSetBool(self):
1616        use_bool = pg.get_bool()
1617        # error if no parameter is passed
1618        self.assertRaises(TypeError, pg.set_bool)
1619        query = self.c.query
1620        try:
1621            r = query("select true::bool")
1622        except pg.ProgrammingError:
1623            self.skipTest('database does not support bool')
1624        r = r.getresult()[0][0]
1625        self.assertIsInstance(r, str)
1626        self.assertEqual(r, 't')
1627        r = query("select true::bool")
1628        pg.set_bool(True)
1629        try:
1630            r = r.getresult()[0][0]
1631        finally:
1632            pg.set_bool(use_bool)
1633        self.assertIsInstance(r, bool)
1634        self.assertIs(r, True)
1635        r = query("select true::bool")
1636        pg.set_bool(False)
1637        try:
1638            r = r.getresult()[0][0]
1639        finally:
1640            pg.set_bool(use_bool)
1641        self.assertIsInstance(r, str)
1642        self.assertIs(r, 't')
1643
1644    def testGetNamedresult(self):
1645        namedresult = pg.get_namedresult()
1646        # error if a parameter is passed
1647        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
1648        self.assertIs(namedresult, pg._namedresult)  # the default setting
1649
1650    def testSetNamedresult(self):
1651        namedresult = pg.get_namedresult()
1652        self.assertTrue(callable(namedresult))
1653
1654        query = self.c.query
1655
1656        r = query("select 1 as x, 2 as y").namedresult()[0]
1657        self.assertIsInstance(r, tuple)
1658        self.assertEqual(r, (1, 2))
1659        self.assertIsNot(type(r), tuple)
1660        self.assertEqual(r._fields, ('x', 'y'))
1661        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
1662        self.assertEqual(r.__class__.__name__, 'Row')
1663
1664        def listresult(q):
1665            return [list(row) for row in q.getresult()]
1666
1667        pg.set_namedresult(listresult)
1668        try:
1669            r = pg.get_namedresult()
1670            self.assertIs(r, listresult)
1671            r = query("select 1 as x, 2 as y").namedresult()[0]
1672            self.assertIsInstance(r, list)
1673            self.assertEqual(r, [1, 2])
1674            self.assertIsNot(type(r), tuple)
1675            self.assertFalse(hasattr(r, '_fields'))
1676            self.assertNotEqual(r.__class__.__name__, 'Row')
1677        finally:
1678            pg.set_namedresult(namedresult)
1679
1680        r = pg.get_namedresult()
1681        self.assertIs(r, namedresult)
1682
1683
class TestStandaloneEscapeFunctions(unittest.TestCase):
    """Test pg escape functions.

    The libpq interface memorizes some parameters of the last opened
    connection that influence the result of these functions.  Therefore
    we need to open a connection with fixed parameters prior to testing
    in order to ensure that the tests always run under the same conditions.
    That's why these tests are included in this test module.
    """

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Open a connection once to fix the parameters libpq memorizes."""
        db = connect()
        try:
            query = db.query
            query('set client_encoding=sql_ascii')
            query('set standard_conforming_strings=off')
            query('set bytea_output=escape')
        finally:
            # Close the connection explicitly instead of relying on garbage
            # collection.  The original code also discarded the connection
            # when this method returned, so the escape tests evidently only
            # need the parameters libpq memorized from it.
            # NOTE(review): confirm this holds for all supported libpq
            # versions.
            db.close()
        cls.cls_set_up = True

    def testEscapeString(self):
        # escape_string() returns the same type it was given (bytes or
        # unicode) and doubles single quotes; a backslash is doubled too,
        # because standard_conforming_strings is off (see setUpClass).
        self.assertTrue(self.cls_set_up)
        f = pg.escape_string
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # \xe4 is a-umlaut; written as an escape so the test does not
        # depend on how this source file's encoding is preserved
        r = f(u"das is' k\xe4se".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, u"das is'' k\xe4se".encode('utf-8'))
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(r"It's bad to have a \ inside.")
        self.assertEqual(r, r"It''s bad to have a \\ inside.")

    def testEscapeBytea(self):
        # escape_bytea() returns the same type it was given and, with
        # bytea_output=escape (see setUpClass), emits non-ASCII bytes as
        # doubled-backslash octal escapes.
        self.assertTrue(self.cls_set_up)
        f = pg.escape_bytea
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # UTF-8 for a-umlaut (\xe4) is the byte pair \303\244
        r = f(u"das is' k\xe4se".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(b'O\x00ps\xff!')
        self.assertEqual(r, b'O\\\\000ps\\\\377!')
1739
1740
if __name__ == "__main__":
    # Executed as a script: discover and run every test case in this module.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.