source: trunk/tests/test_classic_connection.py @ 781

Last change on this file since 781 was 781, checked in by cito, 3 years ago

Add full support for PostgreSQL array types

At the core of this patch is a fast parser for the peculiar syntax of
literal array expressions in PostgreSQL that was added to the C module.
This is not trivial, because PostgreSQL arrays can be multidimensional
and the syntax is different from Python and SQL expressions.

The Python pg and pgdb modules make use of this parser so that they can
return database columns containing PostgreSQL arrays to Python as lists.
Also added quoting methods that allow passing PostgreSQL arrays as lists
to insert()/update() and execute/executemany(). These methods are simpler
and were implemented in Python but needed support from the regex module.

The patch also makes getresult() in pg automatically return bytea
values in unescaped form as bytes strings. Before, it was necessary to
call unescape_bytea manually. The pgdb module did this already.

The patch includes some more refactorings and simplifications regarding
the quoting and casting in pg and pgdb.

Some references to antique PostgreSQL types that are not used any more
in the supported PostgreSQL versions have been removed.

Also added documentation and tests for the new features.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 61.9 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17import threading
18import time
19import os
20
21import pg  # the module under test
22
23from decimal import Decimal
24
25# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
26# get our information from that.  Otherwise we use the defaults.
27# These tests should be run with various PostgreSQL versions and databases
28# created with different encodings and locales.  Particularly, make sure the
29# tests are running against databases created with both SQL_ASCII and UTF8.
30dbname = 'unittest'
31dbhost = None
32dbport = 5432
33
34try:
35    from .LOCAL_PyGreSQL import *
36except (ImportError, ValueError):
37    try:
38        from LOCAL_PyGreSQL import *
39    except ImportError:
40        pass
41
42try:
43    long
44except NameError:  # Python >= 3.0
45    long = int
46
47try:
48    unicode
49except NameError:  # Python >= 3.0
50    unicode = str
51
52unicode_strings = str is not bytes
53
54windows = os.name == 'nt'
55
56# There is a known bug in libpq under Windows which can cause
57# the interface to crash when calling PQhost():
58do_not_ask_for_host = windows
59do_not_ask_for_host_reason = 'libpq issue on Windows'
60
61
def connect():
    """Open a pg connection to the test database.

    The connection is made with the module-level dbname, dbhost and dbport
    settings, and client_min_messages is set to warning on it before it is
    handed back to the caller.
    """
    db = pg.connect(dbname, dbhost, dbport)
    db.query("set client_min_messages=warning")
    return db
67
68
class TestCanConnect(unittest.TestCase):
    """Check that the test database can be reached at all."""

    def testCanConnect(self):
        """Open a connection and close it again, failing on any pg.Error."""
        try:
            db = connect()
        except pg.Error as exc:
            self.fail('Cannot connect to database %s:\n%s' % (dbname, exc))
        try:
            db.close()
        except pg.Error:
            self.fail('Cannot close the database connection')
81
82
83class TestConnectObject(unittest.TestCase):
84    """Test existence of basic pg connection methods."""
85
86    def setUp(self):
87        self.connection = connect()
88
89    def tearDown(self):
90        try:
91            self.connection.close()
92        except pg.InternalError:
93            pass
94
95    def is_method(self, attribute):
96        """Check if given attribute on the connection is a method."""
97        if do_not_ask_for_host and attribute == 'host':
98            return False
99        return callable(getattr(self.connection, attribute))
100
101    def testClassName(self):
102        self.assertEqual(self.connection.__class__.__name__, 'Connection')
103
104    def testModuleName(self):
105        self.assertEqual(self.connection.__class__.__module__, 'pg')
106
107    def testStr(self):
108        r = str(self.connection)
109        self.assertTrue(r.startswith('<pg.Connection object'), r)
110
111    def testRepr(self):
112        r = repr(self.connection)
113        self.assertTrue(r.startswith('<pg.Connection object'), r)
114
115    def testAllConnectAttributes(self):
116        attributes = '''db error host options port
117            protocol_version server_version status user'''.split()
118        connection_attributes = [a for a in dir(self.connection)
119            if not a.startswith('__') and not self.is_method(a)]
120        self.assertEqual(attributes, connection_attributes)
121
122    def testAllConnectMethods(self):
123        methods = '''cancel close endcopy
124            escape_bytea escape_identifier escape_literal escape_string
125            fileno get_notice_receiver getline getlo getnotify
126            inserttable locreate loimport parameter putline query reset
127            set_notice_receiver source transaction'''.split()
128        connection_methods = [a for a in dir(self.connection)
129            if not a.startswith('__') and self.is_method(a)]
130        self.assertEqual(methods, connection_methods)
131
132    def testAttributeDb(self):
133        self.assertEqual(self.connection.db, dbname)
134
135    def testAttributeError(self):
136        error = self.connection.error
137        self.assertTrue(not error or 'krb5_' in error)
138
139    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
140    def testAttributeHost(self):
141        def_host = 'localhost'
142        self.assertIsInstance(self.connection.host, str)
143        self.assertEqual(self.connection.host, dbhost or def_host)
144
145    def testAttributeOptions(self):
146        no_options = ''
147        self.assertEqual(self.connection.options, no_options)
148
149    def testAttributePort(self):
150        def_port = 5432
151        self.assertIsInstance(self.connection.port, int)
152        self.assertEqual(self.connection.port, dbport or def_port)
153
154    def testAttributeProtocolVersion(self):
155        protocol_version = self.connection.protocol_version
156        self.assertIsInstance(protocol_version, int)
157        self.assertTrue(2 <= protocol_version < 4)
158
159    def testAttributeServerVersion(self):
160        server_version = self.connection.server_version
161        self.assertIsInstance(server_version, int)
162        self.assertTrue(70400 <= server_version < 100000)
163
164    def testAttributeStatus(self):
165        status_ok = 1
166        self.assertIsInstance(self.connection.status, int)
167        self.assertEqual(self.connection.status, status_ok)
168
169    def testAttributeUser(self):
170        no_user = 'Deprecated facility'
171        user = self.connection.user
172        self.assertTrue(user)
173        self.assertIsInstance(user, str)
174        self.assertNotEqual(user, no_user)
175
176    def testMethodQuery(self):
177        query = self.connection.query
178        query("select 1+1")
179        query("select 1+$1", (1,))
180        query("select 1+$1+$2", (2, 3))
181        query("select 1+$1+$2", [2, 3])
182
183    def testMethodQueryEmpty(self):
184        self.assertRaises(ValueError, self.connection.query, '')
185
186    def testMethodEndcopy(self):
187        try:
188            self.connection.endcopy()
189        except IOError:
190            pass
191
192    def testMethodClose(self):
193        self.connection.close()
194        try:
195            self.connection.reset()
196        except (pg.Error, TypeError):
197            pass
198        else:
199            self.fail('Reset should give an error for a closed connection')
200        self.assertRaises(pg.InternalError, self.connection.close)
201        try:
202            self.connection.query('select 1')
203        except (pg.Error, TypeError):
204            pass
205        else:
206            self.fail('Query should give an error for a closed connection')
207        self.connection = connect()
208
209    def testMethodReset(self):
210        query = self.connection.query
211        # check that client encoding gets reset
212        encoding = query('show client_encoding').getresult()[0][0].upper()
213        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
214        self.assertNotEqual(encoding, changed_encoding)
215        self.connection.query("set client_encoding=%s" % changed_encoding)
216        new_encoding = query('show client_encoding').getresult()[0][0].upper()
217        self.assertEqual(new_encoding, changed_encoding)
218        self.connection.reset()
219        new_encoding = query('show client_encoding').getresult()[0][0].upper()
220        self.assertNotEqual(new_encoding, changed_encoding)
221        self.assertEqual(new_encoding, encoding)
222
223    def testMethodCancel(self):
224        r = self.connection.cancel()
225        self.assertIsInstance(r, int)
226        self.assertEqual(r, 1)
227
228    def testCancelLongRunningThread(self):
229        errors = []
230
231        def sleep():
232            try:
233                self.connection.query('select pg_sleep(5)').getresult()
234            except pg.ProgrammingError as error:
235                errors.append(str(error))
236
237        thread = threading.Thread(target=sleep)
238        t1 = time.time()
239        thread.start()  # run the query
240        while 1:  # make sure the query is really running
241            time.sleep(0.1)
242            if thread.is_alive() or time.time() - t1 > 5:
243                break
244        r = self.connection.cancel()  # cancel the running query
245        thread.join()  # wait for the thread to end
246        t2 = time.time()
247
248        self.assertIsInstance(r, int)
249        self.assertEqual(r, 1)  # return code should be 1
250        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
251        self.assertTrue(errors)
252
253    def testMethodFileNo(self):
254        r = self.connection.fileno()
255        self.assertIsInstance(r, int)
256        self.assertGreaterEqual(r, 0)
257
258
259class TestSimpleQueries(unittest.TestCase):
260    """Test simple queries via a basic pg connection."""
261
262    def setUp(self):
263        self.c = connect()
264
265    def tearDown(self):
266        self.doCleanups()
267        self.c.close()
268
269    def testClassName(self):
270        r = self.c.query("select 1")
271        self.assertEqual(r.__class__.__name__, 'Query')
272
273    def testModuleName(self):
274        r = self.c.query("select 1")
275        self.assertEqual(r.__class__.__module__, 'pg')
276
277    def testStr(self):
278        q = ("select 1 as a, 'hello' as h, 'w' as world"
279            " union select 2, 'xyz', 'uvw'")
280        r = self.c.query(q)
281        self.assertEqual(str(r),
282            'a|  h  |world\n'
283            '-+-----+-----\n'
284            '1|hello|w    \n'
285            '2|xyz  |uvw  \n'
286            '(2 rows)')
287
288    def testRepr(self):
289        r = repr(self.c.query("select 1"))
290        self.assertTrue(r.startswith('<pg.Query object'), r)
291
292    def testSelect0(self):
293        q = "select 0"
294        self.c.query(q)
295
296    def testSelect0Semicolon(self):
297        q = "select 0;"
298        self.c.query(q)
299
300    def testSelectDotSemicolon(self):
301        q = "select .;"
302        self.assertRaises(pg.ProgrammingError, self.c.query, q)
303
304    def testGetresult(self):
305        q = "select 0"
306        result = [(0,)]
307        r = self.c.query(q).getresult()
308        self.assertIsInstance(r, list)
309        v = r[0]
310        self.assertIsInstance(v, tuple)
311        self.assertIsInstance(v[0], int)
312        self.assertEqual(r, result)
313
314    def testGetresultLong(self):
315        q = "select 9876543210"
316        result = long(9876543210)
317        self.assertIsInstance(result, long)
318        v = self.c.query(q).getresult()[0][0]
319        self.assertIsInstance(v, long)
320        self.assertEqual(v, result)
321
322    def testGetresultDecimal(self):
323        q = "select 98765432109876543210"
324        result = Decimal(98765432109876543210)
325        v = self.c.query(q).getresult()[0][0]
326        self.assertIsInstance(v, Decimal)
327        self.assertEqual(v, result)
328
329    def testGetresultString(self):
330        result = 'Hello, world!'
331        q = "select '%s'" % result
332        v = self.c.query(q).getresult()[0][0]
333        self.assertIsInstance(v, str)
334        self.assertEqual(v, result)
335
336    def testDictresult(self):
337        q = "select 0 as alias0"
338        result = [{'alias0': 0}]
339        r = self.c.query(q).dictresult()
340        self.assertIsInstance(r, list)
341        v = r[0]
342        self.assertIsInstance(v, dict)
343        self.assertIsInstance(v['alias0'], int)
344        self.assertEqual(r, result)
345
346    def testDictresultLong(self):
347        q = "select 9876543210 as longjohnsilver"
348        result = long(9876543210)
349        self.assertIsInstance(result, long)
350        v = self.c.query(q).dictresult()[0]['longjohnsilver']
351        self.assertIsInstance(v, long)
352        self.assertEqual(v, result)
353
354    def testDictresultDecimal(self):
355        q = "select 98765432109876543210 as longjohnsilver"
356        result = Decimal(98765432109876543210)
357        v = self.c.query(q).dictresult()[0]['longjohnsilver']
358        self.assertIsInstance(v, Decimal)
359        self.assertEqual(v, result)
360
361    def testDictresultString(self):
362        result = 'Hello, world!'
363        q = "select '%s' as greeting" % result
364        v = self.c.query(q).dictresult()[0]['greeting']
365        self.assertIsInstance(v, str)
366        self.assertEqual(v, result)
367
368    def testNamedresult(self):
369        q = "select 0 as alias0"
370        result = [(0,)]
371        r = self.c.query(q).namedresult()
372        self.assertEqual(r, result)
373        v = r[0]
374        self.assertEqual(v._fields, ('alias0',))
375        self.assertEqual(v.alias0, 0)
376
377    def testGet3Cols(self):
378        q = "select 1,2,3"
379        result = [(1, 2, 3)]
380        r = self.c.query(q).getresult()
381        self.assertEqual(r, result)
382
383    def testGet3DictCols(self):
384        q = "select 1 as a,2 as b,3 as c"
385        result = [dict(a=1, b=2, c=3)]
386        r = self.c.query(q).dictresult()
387        self.assertEqual(r, result)
388
389    def testGet3NamedCols(self):
390        q = "select 1 as a,2 as b,3 as c"
391        result = [(1, 2, 3)]
392        r = self.c.query(q).namedresult()
393        self.assertEqual(r, result)
394        v = r[0]
395        self.assertEqual(v._fields, ('a', 'b', 'c'))
396        self.assertEqual(v.b, 2)
397
398    def testGet3Rows(self):
399        q = "select 3 union select 1 union select 2 order by 1"
400        result = [(1,), (2,), (3,)]
401        r = self.c.query(q).getresult()
402        self.assertEqual(r, result)
403
404    def testGet3DictRows(self):
405        q = ("select 3 as alias3"
406            " union select 1 union select 2 order by 1")
407        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
408        r = self.c.query(q).dictresult()
409        self.assertEqual(r, result)
410
411    def testGet3NamedRows(self):
412        q = ("select 3 as alias3"
413            " union select 1 union select 2 order by 1")
414        result = [(1,), (2,), (3,)]
415        r = self.c.query(q).namedresult()
416        self.assertEqual(r, result)
417        for v in r:
418            self.assertEqual(v._fields, ('alias3',))
419
420    def testDictresultNames(self):
421        q = "select 'MixedCase' as MixedCaseAlias"
422        result = [{'mixedcasealias': 'MixedCase'}]
423        r = self.c.query(q).dictresult()
424        self.assertEqual(r, result)
425        q = "select 'MixedCase' as \"MixedCaseAlias\""
426        result = [{'MixedCaseAlias': 'MixedCase'}]
427        r = self.c.query(q).dictresult()
428        self.assertEqual(r, result)
429
430    def testNamedresultNames(self):
431        q = "select 'MixedCase' as MixedCaseAlias"
432        result = [('MixedCase',)]
433        r = self.c.query(q).namedresult()
434        self.assertEqual(r, result)
435        v = r[0]
436        self.assertEqual(v._fields, ('mixedcasealias',))
437        self.assertEqual(v.mixedcasealias, 'MixedCase')
438        q = "select 'MixedCase' as \"MixedCaseAlias\""
439        r = self.c.query(q).namedresult()
440        self.assertEqual(r, result)
441        v = r[0]
442        self.assertEqual(v._fields, ('MixedCaseAlias',))
443        self.assertEqual(v.MixedCaseAlias, 'MixedCase')
444
445    def testBigGetresult(self):
446        num_cols = 100
447        num_rows = 100
448        q = "select " + ','.join(map(str, range(num_cols)))
449        q = ' union all '.join((q,) * num_rows)
450        r = self.c.query(q).getresult()
451        result = [tuple(range(num_cols))] * num_rows
452        self.assertEqual(r, result)
453
454    def testListfields(self):
455        q = ('select 0 as a, 0 as b, 0 as c,'
456            ' 0 as c, 0 as b, 0 as a,'
457            ' 0 as lowercase, 0 as UPPERCASE,'
458            ' 0 as MixedCase, 0 as "MixedCase",'
459            ' 0 as a_long_name_with_underscores,'
460            ' 0 as "A long name with Blanks"')
461        r = self.c.query(q).listfields()
462        result = ('a', 'b', 'c', 'c', 'b', 'a',
463            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
464            'a_long_name_with_underscores',
465            'A long name with Blanks')
466        self.assertEqual(r, result)
467
468    def testFieldname(self):
469        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
470        r = self.c.query(q).fieldname(2)
471        self.assertEqual(r, 'x')
472        r = self.c.query(q).fieldname(3)
473        self.assertEqual(r, 'y')
474
475    def testFieldnum(self):
476        q = "select 1 as x"
477        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
478        q = "select 1 as x"
479        r = self.c.query(q).fieldnum('x')
480        self.assertIsInstance(r, int)
481        self.assertEqual(r, 0)
482        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
483        r = self.c.query(q).fieldnum('x')
484        self.assertIsInstance(r, int)
485        self.assertEqual(r, 2)
486        r = self.c.query(q).fieldnum('y')
487        self.assertIsInstance(r, int)
488        self.assertEqual(r, 3)
489
490    def testNtuples(self):
491        q = "select 1 where false"
492        r = self.c.query(q).ntuples()
493        self.assertIsInstance(r, int)
494        self.assertEqual(r, 0)
495        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
496            " union select 5 as a, 6 as b, 7 as c, 8 as d")
497        r = self.c.query(q).ntuples()
498        self.assertIsInstance(r, int)
499        self.assertEqual(r, 2)
500        q = ("select 1 union select 2 union select 3"
501            " union select 4 union select 5 union select 6")
502        r = self.c.query(q).ntuples()
503        self.assertIsInstance(r, int)
504        self.assertEqual(r, 6)
505
506    def testQuery(self):
507        query = self.c.query
508        query("drop table if exists test_table")
509        self.addCleanup(query, "drop table test_table")
510        q = "create table test_table (n integer) with oids"
511        r = query(q)
512        self.assertIsNone(r)
513        q = "insert into test_table values (1)"
514        r = query(q)
515        self.assertIsInstance(r, int)
516        q = "insert into test_table select 2"
517        r = query(q)
518        self.assertIsInstance(r, int)
519        oid = r
520        q = "select oid from test_table where n=2"
521        r = query(q).getresult()
522        self.assertEqual(len(r), 1)
523        r = r[0]
524        self.assertEqual(len(r), 1)
525        r = r[0]
526        self.assertIsInstance(r, int)
527        self.assertEqual(r, oid)
528        q = "insert into test_table select 3 union select 4 union select 5"
529        r = query(q)
530        self.assertIsInstance(r, str)
531        self.assertEqual(r, '3')
532        q = "update test_table set n=4 where n<5"
533        r = query(q)
534        self.assertIsInstance(r, str)
535        self.assertEqual(r, '4')
536        q = "delete from test_table"
537        r = query(q)
538        self.assertIsInstance(r, str)
539        self.assertEqual(r, '5')
540
541
542class TestUnicodeQueries(unittest.TestCase):
543    """Test unicode strings as queries via a basic pg connection."""
544
545    def setUp(self):
546        self.c = connect()
547        self.c.query('set client_encoding=utf8')
548
549    def tearDown(self):
550        self.c.close()
551
552    def testGetresulAscii(self):
553        result = u'Hello, world!'
554        q = u"select '%s'" % result
555        v = self.c.query(q).getresult()[0][0]
556        self.assertIsInstance(v, str)
557        self.assertEqual(v, result)
558
559    def testDictresulAscii(self):
560        result = u'Hello, world!'
561        q = u"select '%s' as greeting" % result
562        v = self.c.query(q).dictresult()[0]['greeting']
563        self.assertIsInstance(v, str)
564        self.assertEqual(v, result)
565
566    def testGetresultUtf8(self):
567        result = u'Hello, wörld & ЌОр!'
568        q = u"select '%s'" % result
569        if not unicode_strings:
570            result = result.encode('utf8')
571        # pass the query as unicode
572        try:
573            v = self.c.query(q).getresult()[0][0]
574        except pg.ProgrammingError:
575            self.skipTest("database does not support utf8")
576        self.assertIsInstance(v, str)
577        self.assertEqual(v, result)
578        q = q.encode('utf8')
579        # pass the query as bytes
580        v = self.c.query(q).getresult()[0][0]
581        self.assertIsInstance(v, str)
582        self.assertEqual(v, result)
583
584    def testDictresultUtf8(self):
585        result = u'Hello, wörld & ЌОр!'
586        q = u"select '%s' as greeting" % result
587        if not unicode_strings:
588            result = result.encode('utf8')
589        try:
590            v = self.c.query(q).dictresult()[0]['greeting']
591        except pg.ProgrammingError:
592            self.skipTest("database does not support utf8")
593        self.assertIsInstance(v, str)
594        self.assertEqual(v, result)
595        q = q.encode('utf8')
596        v = self.c.query(q).dictresult()[0]['greeting']
597        self.assertIsInstance(v, str)
598        self.assertEqual(v, result)
599
600    def testDictresultLatin1(self):
601        try:
602            self.c.query('set client_encoding=latin1')
603        except pg.ProgrammingError:
604            self.skipTest("database does not support latin1")
605        result = u'Hello, wörld!'
606        q = u"select '%s'" % result
607        if not unicode_strings:
608            result = result.encode('latin1')
609        v = self.c.query(q).getresult()[0][0]
610        self.assertIsInstance(v, str)
611        self.assertEqual(v, result)
612        q = q.encode('latin1')
613        v = self.c.query(q).getresult()[0][0]
614        self.assertIsInstance(v, str)
615        self.assertEqual(v, result)
616
617    def testDictresultLatin1(self):
618        try:
619            self.c.query('set client_encoding=latin1')
620        except pg.ProgrammingError:
621            self.skipTest("database does not support latin1")
622        result = u'Hello, wörld!'
623        q = u"select '%s' as greeting" % result
624        if not unicode_strings:
625            result = result.encode('latin1')
626        v = self.c.query(q).dictresult()[0]['greeting']
627        self.assertIsInstance(v, str)
628        self.assertEqual(v, result)
629        q = q.encode('latin1')
630        v = self.c.query(q).dictresult()[0]['greeting']
631        self.assertIsInstance(v, str)
632        self.assertEqual(v, result)
633
634    def testGetresultCyrillic(self):
635        try:
636            self.c.query('set client_encoding=iso_8859_5')
637        except pg.ProgrammingError:
638            self.skipTest("database does not support cyrillic")
639        result = u'Hello, ЌОр!'
640        q = u"select '%s'" % result
641        if not unicode_strings:
642            result = result.encode('cyrillic')
643        v = self.c.query(q).getresult()[0][0]
644        self.assertIsInstance(v, str)
645        self.assertEqual(v, result)
646        q = q.encode('cyrillic')
647        v = self.c.query(q).getresult()[0][0]
648        self.assertIsInstance(v, str)
649        self.assertEqual(v, result)
650
651    def testDictresultCyrillic(self):
652        try:
653            self.c.query('set client_encoding=iso_8859_5')
654        except pg.ProgrammingError:
655            self.skipTest("database does not support cyrillic")
656        result = u'Hello, ЌОр!'
657        q = u"select '%s' as greeting" % result
658        if not unicode_strings:
659            result = result.encode('cyrillic')
660        v = self.c.query(q).dictresult()[0]['greeting']
661        self.assertIsInstance(v, str)
662        self.assertEqual(v, result)
663        q = q.encode('cyrillic')
664        v = self.c.query(q).dictresult()[0]['greeting']
665        self.assertIsInstance(v, str)
666        self.assertEqual(v, result)
667
668    def testGetresultLatin9(self):
669        try:
670            self.c.query('set client_encoding=latin9')
671        except pg.ProgrammingError:
672            self.skipTest("database does not support latin9")
673        result = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
674        q = u"select '%s'" % result
675        if not unicode_strings:
676            result = result.encode('latin9')
677        v = self.c.query(q).getresult()[0][0]
678        self.assertIsInstance(v, str)
679        self.assertEqual(v, result)
680        q = q.encode('latin9')
681        v = self.c.query(q).getresult()[0][0]
682        self.assertIsInstance(v, str)
683        self.assertEqual(v, result)
684
685    def testDictresultLatin9(self):
686        try:
687            self.c.query('set client_encoding=latin9')
688        except pg.ProgrammingError:
689            self.skipTest("database does not support latin9")
690        result = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
691        q = u"select '%s' as menu" % result
692        if not unicode_strings:
693            result = result.encode('latin9')
694        v = self.c.query(q).dictresult()[0]['menu']
695        self.assertIsInstance(v, str)
696        self.assertEqual(v, result)
697        q = q.encode('latin9')
698        v = self.c.query(q).dictresult()[0]['menu']
699        self.assertIsInstance(v, str)
700        self.assertEqual(v, result)
701
702
703class TestParamQueries(unittest.TestCase):
704    """Test queries with parameters via a basic pg connection."""
705
706    def setUp(self):
707        self.c = connect()
708        self.c.query('set client_encoding=utf8')
709
710    def tearDown(self):
711        self.c.close()
712
713    def testQueryWithNoneParam(self):
714        self.assertRaises(TypeError, self.c.query, "select $1", None)
715        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
716        self.assertEqual(self.c.query("select $1::integer", (None,)
717            ).getresult(), [(None,)])
718        self.assertEqual(self.c.query("select $1::text", [None]
719            ).getresult(), [(None,)])
720        self.assertEqual(self.c.query("select $1::text", [[None]]
721            ).getresult(), [(None,)])
722
723    def testQueryWithBoolParams(self, use_bool=None):
724        query = self.c.query
725        if use_bool is not None:
726            use_bool_default = pg.get_bool()
727            pg.set_bool(use_bool)
728        try:
729            v_false, v_true = (False, True) if use_bool else 'ft'
730            r_false, r_true = [(v_false,)], [(v_true,)]
731            self.assertEqual(query("select false").getresult(), r_false)
732            self.assertEqual(query("select true").getresult(), r_true)
733            q = "select $1::bool"
734            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
735            self.assertEqual(query(q, ('f',)).getresult(), r_false)
736            self.assertEqual(query(q, ('t',)).getresult(), r_true)
737            self.assertEqual(query(q, ('false',)).getresult(), r_false)
738            self.assertEqual(query(q, ('true',)).getresult(), r_true)
739            self.assertEqual(query(q, ('n',)).getresult(), r_false)
740            self.assertEqual(query(q, ('y',)).getresult(), r_true)
741            self.assertEqual(query(q, (0,)).getresult(), r_false)
742            self.assertEqual(query(q, (1,)).getresult(), r_true)
743            self.assertEqual(query(q, (False,)).getresult(), r_false)
744            self.assertEqual(query(q, (True,)).getresult(), r_true)
745        finally:
746            if use_bool is not None:
747                pg.set_bool(use_bool_default)
748
749    def testQueryWithBoolParamsAndUseBool(self):
750        self.testQueryWithBoolParams(use_bool=True)
751
752    def testQueryWithIntParams(self):
753        query = self.c.query
754        self.assertEqual(query("select 1+1").getresult(), [(2,)])
755        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
756        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
757        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
758        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
759        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
760            [(Decimal('2'),)])
761        self.assertEqual(query("select 1, $1::integer", (2,)
762            ).getresult(), [(1, 2)])
763        self.assertEqual(query("select 1 union select $1::integer", (2,)
764            ).getresult(), [(1,), (2,)])
765        self.assertEqual(query("select $1::integer+$2", (1, 2)
766            ).getresult(), [(3,)])
767        self.assertEqual(query("select $1::integer+$2", [1, 2]
768            ).getresult(), [(3,)])
769        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
770            ).getresult(), [(15,)])
771
772    def testQueryWithStrParams(self):
773        query = self.c.query
774        self.assertEqual(query("select $1||', world!'", ('Hello',)
775            ).getresult(), [('Hello, world!',)])
776        self.assertEqual(query("select $1||', world!'", ['Hello']
777            ).getresult(), [('Hello, world!',)])
778        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
779            ).getresult(), [('Hello, world!',)])
780        self.assertEqual(query("select $1::text", ('Hello, world!',)
781            ).getresult(), [('Hello, world!',)])
782        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
783            ).getresult(), [('Hello', 'world')])
784        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
785            ).getresult(), [('Hello', 'world')])
786        self.assertEqual(query("select $1::text union select $2::text",
787            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
788        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
789            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
790
791    def testQueryWithUnicodeParams(self):
792        query = self.c.query
793        try:
794            query('set client_encoding=utf8')
795            query("select 'wörld'").getresult()[0][0] == 'wörld'
796        except pg.ProgrammingError:
797            self.skipTest("database does not support utf8")
798        self.assertEqual(query("select $1||', '||$2||'!'",
799            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
800
801    def testQueryWithUnicodeParamsLatin1(self):
802        query = self.c.query
803        try:
804            query('set client_encoding=latin1')
805            query("select 'wörld'").getresult()[0][0] == 'wörld'
806        except pg.ProgrammingError:
807            self.skipTest("database does not support latin1")
808        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
809        if unicode_strings:
810            self.assertEqual(r, [('Hello, wörld!',)])
811        else:
812            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
813        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
814            ('Hello', u'ЌОр'))
815        query('set client_encoding=iso_8859_1')
816        r = query("select $1||', '||$2||'!'",
817            ('Hello', u'wörld')).getresult()
818        if unicode_strings:
819            self.assertEqual(r, [('Hello, wörld!',)])
820        else:
821            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
822        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
823            ('Hello', u'ЌОр'))
824        query('set client_encoding=sql_ascii')
825        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
826            ('Hello', u'wörld'))
827
828    def testQueryWithUnicodeParamsCyrillic(self):
829        query = self.c.query
830        try:
831            query('set client_encoding=iso_8859_5')
832            query("select 'ЌОр'").getresult()[0][0] == 'ЌОр'
833        except pg.ProgrammingError:
834            self.skipTest("database does not support cyrillic")
835        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
836            ('Hello', u'wörld'))
837        r = query("select $1||', '||$2||'!'",
838            ('Hello', u'ЌОр')).getresult()
839        if unicode_strings:
840            self.assertEqual(r, [('Hello, ЌОр!',)])
841        else:
842            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
843        query('set client_encoding=sql_ascii')
844        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
845            ('Hello', u'ЌОр!'))
846
847    def testQueryWithMixedParams(self):
848        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
849            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
850        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
851            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
852
853    def testQueryWithDuplicateParams(self):
854        self.assertRaises(pg.ProgrammingError,
855            self.c.query, "select $1+$1", (1,))
856        self.assertRaises(pg.ProgrammingError,
857            self.c.query, "select $1+$1", (1, 2))
858
859    def testQueryWithZeroParams(self):
860        self.assertEqual(self.c.query("select 1+1", []
861            ).getresult(), [(2,)])
862
863    def testQueryWithGarbage(self):
864        garbage = r"'\{}+()-#[]oo324"
865        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
866            ).dictresult(), [{'garbage': garbage}])
867
868
class TestQueryResultTypes(unittest.TestCase):
    """Test proper result types via a basic pg connection."""

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        self.c.close()

    def assert_proper_cast(self, value, pgtype, pytype):
        """Check the Python type returned for a value of a PostgreSQL type.

        The value is passed as a query parameter and cast to pgtype,
        first as a scalar and then as a one-element array of pgtype.
        The scalar result and the single array element must both be
        instances of pytype.
        """
        q = 'select $1::%s' % (pgtype,)
        r = self.c.query(q, (value,)).getresult()[0][0]
        self.assertIsInstance(r, pytype)
        # Inside an array literal, only quote string elements that need it:
        # empty strings and strings containing spaces or braces.
        if isinstance(value, str) and (
                not value or ' ' in value or '{' in value):
            value = '"%s"' % value
        value = '{%s}' % value
        r = self.c.query(q + '[]', (value,)).getresult()[0][0]
        self.assertIsInstance(r, list)
        self.assertEqual(len(r), 1)
        self.assertIsInstance(r[0], pytype)

    def testInt(self):
        self.assert_proper_cast(0, 'int', int)
        self.assert_proper_cast(0, 'smallint', int)
        self.assert_proper_cast(0, 'oid', int)
        self.assert_proper_cast(0, 'cid', int)
        self.assert_proper_cast(0, 'xid', int)

    def testLong(self):
        self.assert_proper_cast(0, 'bigint', long)

    def testFloat(self):
        self.assert_proper_cast(0, 'float', float)
        self.assert_proper_cast(0, 'real', float)
        # PostgreSQL has no type called "double"; float8 is the
        # one-word alias of "double precision"
        self.assert_proper_cast(0, 'float8', float)
        self.assert_proper_cast(0, 'double precision', float)
        self.assert_proper_cast('infinity', 'float', float)

    # This test must not share its name with testFloat,
    # otherwise it would replace that test in the class.
    def testNumeric(self):
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal(0), 'numeric', decimal)
        self.assert_proper_cast(decimal(0), 'decimal', decimal)

    def testMoney(self):
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal('0'), 'money', decimal)

    def testBool(self):
        bool_type = bool if pg.get_bool() else str
        self.assert_proper_cast('f', 'bool', bool_type)

    def testDate(self):
        self.assert_proper_cast('1956-01-31', 'date', str)
        self.assert_proper_cast('0', 'interval', str)
        self.assert_proper_cast('08:42', 'time', str)
        self.assert_proper_cast('08:42', 'timetz', str)
        self.assert_proper_cast('1956-01-31 08:42', 'timestamp', str)
        self.assert_proper_cast('1956-01-31 08:42', 'timestamptz', str)

    def testText(self):
        self.assert_proper_cast('', 'text', str)
        self.assert_proper_cast('', 'char', str)
        self.assert_proper_cast('', 'bpchar', str)
        self.assert_proper_cast('', 'varchar', str)

    def testBytea(self):
        self.assert_proper_cast('', 'bytea', bytes)

    def testJson(self):
        self.assert_proper_cast('{}', 'json', dict)
942
943
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        c = connect()
        c.query("drop table if exists test cascade")
        c.query("create table test ("
            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
            "d numeric, f4 real, f8 double precision, m money,"
            "c char(1), v4 varchar(4), c4 char(4), t text)")
        # Check whether the test database uses SQL_ASCII - this means
        # that it does not consider encoding when calculating lengths.
        c.query("set client_encoding=utf8")
        cls.has_encoding = c.query(
            "select length('À') - length('a')").getresult()[0][0] == 0
        c.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        c = connect()
        c.query("drop table test cascade")
        c.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    # sample rows with one value for each column of the test table
    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def db_len(cls, s, encoding):
        """Return the length of s in the unit the class expects.

        If has_encoding is set, the length is counted on the unicode
        form of s, otherwise on its form encoded with the given encoding.
        """
        if cls.has_encoding:
            s = s if isinstance(s, unicode) else s.decode(encoding)
        else:
            s = s.encode(encoding) if isinstance(s, unicode) else s
        return len(s)

    def _unicode_row(self, text):
        """Return a row of unicode values with the given text column."""
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        return (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0', c, u'bÀd', u'bÀd', text)

    @staticmethod
    def _encode_row(row, encoding):
        """Return a copy of row with all unicode values encoded."""
        return tuple(s.encode(encoding)
            if isinstance(s, unicode) else s for s in row)

    def get_back(self, encoding='utf-8'):
        """Convert boolean and decimal values back.

        Reads all rows of the test table ordered by the first column,
        checks the type of every value that is not None and returns the
        rows as a list of tuples in a form that can be compared with
        the inserted data.  The encoding is used for checking the
        lengths of the character columns.
        """
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], str)
                row[3] = {'f': False, 't': True}.get(row[3])
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision
                self.assertIsInstance(row[8], float)
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(self.db_len(row[10], encoding), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(self.db_len(row[11], encoding), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(self.db_len(row[12], encoding), 4)
                # the inserted values do not have the blank padding
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            data.append(tuple(row))
        return data

    def testInserttable1Row(self):
        data = self.data[2:3]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        data = self.data
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMultipleRows(self):
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        data = [(None,) * 14] * 100
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableByteValues(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        row_bytes = self._encode_row(row_unicode, 'utf-8')
        data = [row_bytes] * 2
        self.c.inserttable('test', data)
        if unicode_strings:
            data = [row_unicode] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeUtf8(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'utf-8')] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeLatin1(self):
        try:
            self.c.query("set client_encoding=latin1")
            self.c.query("select '¥'")
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode € sign with latin1 encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
        # The replacement must be a single character that exists in latin1,
        # because it also ends up in the char(1) column.
        row_unicode = tuple(s.replace(u'€', u'¥')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'latin1')] * 2
        self.assertEqual(self.get_back('latin1'), data)

    def testInserttableUnicodeLatin9(self):
        try:
            self.c.query("set client_encoding=latin9")
            self.c.query("select '€'")
        except pg.ProgrammingError:
            # skipTest() raises, so nothing after it is executed
            self.skipTest("database does not support latin9")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'latin9')] * 2
        self.assertEqual(self.get_back('latin9'), data)

    def testInserttableNoEncoding(self):
        self.c.query("set client_encoding=sql_ascii")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode non-ascii unicode without a specific encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
1177
1178
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Create a fresh test table with an int and a varchar column."""
        cnx = connect()
        for command in ("drop table if exists test cascade",
                "create table test (i int, v varchar(16))"):
            cnx.query(command)
        cnx.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the test table again."""
        cnx = connect()
        cnx.query("drop table test cascade")
        cnx.close()

    def setUp(self):
        """Open a connection using the utf8 client encoding."""
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")

    def tearDown(self):
        """Empty the test table and close the connection."""
        cnx = self.c
        cnx.query("truncate table test")
        cnx.close()

    def testPutline(self):
        """Copy some rows to the table line by line with putline()."""
        cnx = self.c
        rows = list(enumerate("apple pear plum cherry banana".split()))
        cnx.query("copy test from stdin")
        try:
            for row in rows:
                cnx.putline("%d\t%s\n" % row)
            # the end-of-data marker of the copy command
            cnx.putline("\\.\n")
        finally:
            cnx.endcopy()
        stored = cnx.query("select * from test").getresult()
        self.assertEqual(stored, rows)

    def testPutlineBytesAndUnicode(self):
        """Pass both encoded and native strings to putline()."""
        cnx = self.c
        lines = (u"47\tkÀse\n".encode('utf8'), "35\twÃŒrstel\n", b"\\.\n")
        cnx.query("copy test from stdin")
        try:
            for line in lines:
                cnx.putline(line)
        finally:
            cnx.endcopy()
        stored = cnx.query("select * from test").getresult()
        self.assertEqual(stored, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        """Copy rows from the table line by line with getline()."""
        cnx = self.c
        rows = list(enumerate("apple banana pear plum strawberry".split()))
        cnx.inserttable('test', rows)
        cnx.query("copy test to stdout")
        try:
            # one line per row, then the end marker, then nothing more
            for row in rows:
                self.assertEqual(cnx.getline(), '%d\t%s' % row)
            self.assertEqual(cnx.getline(), '\\.')
            self.assertIsNone(cnx.getline())
        finally:
            try:
                cnx.endcopy()
            except IOError:
                pass

    def testGetlineBytesAndUnicode(self):
        """Check that getline() returns native strings for non-ASCII data."""
        cnx = self.c
        rows = [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')]
        cnx.inserttable('test', rows)
        cnx.query("copy test to stdout")
        try:
            for expected in ('54\tkÀse', '73\twÃŒrstel'):
                line = cnx.getline()
                self.assertIsInstance(line, str)
                self.assertEqual(line, expected)
            self.assertEqual(cnx.getline(), '\\.')
            self.assertIsNone(cnx.getline())
        finally:
            try:
                cnx.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        """Check that wrong numbers of arguments raise a TypeError."""
        cnx = self.c
        self.assertRaises(TypeError, cnx.putline)
        for method in (cnx.getline, cnx.endcopy):
            self.assertRaises(TypeError, method, 'invalid')
1281
1282
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # run the registered cleanups explicitly before closing,
        # because they make use of the connection
        self.doCleanups()
        self.c.close()

    def _assert_notification(self, notification, payload):
        """Check a value returned by getnotify() after a notify command.

        It must be a tuple of three items (a str, an int and a str),
        with 'test_notify' as the first and payload as the last item.
        """
        self.assertIsInstance(notification, tuple)
        self.assertEqual(len(notification), 3)
        self.assertIsInstance(notification[0], str)
        self.assertIsInstance(notification[1], int)
        self.assertIsInstance(notification[2], str)
        self.assertEqual(notification[0], 'test_notify')
        self.assertEqual(notification[2], payload)

    def testGetNotify(self):
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            # notification without a payload
            query("notify test_notify")
            self._assert_notification(getnotify(), '')
            self.assertIsNone(getnotify())
            # notification with a payload
            query("notify test_notify, 'test_payload'")
            self._assert_notification(getnotify(), 'test_payload')
            self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, None)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))

    def testSetAndGetNoticeReceiver(self):
        r = lambda notice: None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)

    def testNoticeReceiver(self):
        self.addCleanup(self.c.query, 'drop function bilbo_notice();')
        self.c.query('''create function bilbo_notice() returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        received = {}

        def notice_receiver(notice):
            """Record all non-dunder attributes of the notice."""
            for attr in dir(notice):
                if attr.startswith('__'):
                    continue
                value = getattr(notice, attr)
                if isinstance(value, str):
                    # map the German severity name to the English one
                    value = value.replace('WARNUNG', 'WARNING')
                received[attr] = value

        self.c.set_notice_receiver(notice_receiver)
        self.c.query('select bilbo_notice()')
        self.assertEqual(received, dict(
            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
            severity='WARNING', primary='Bilbo was here!',
            detail=None, hint=None))
1360
1361
1362class TestConfigFunctions(unittest.TestCase):
1363    """Test the functions for changing default settings.
1364
1365    To test the effect of most of these functions, we need a database
1366    connection.  That's why they are covered in this test module.
1367    """
1368
1369    def setUp(self):
1370        self.c = connect()
1371        self.c.query("set client_encoding=utf8")
1372        self.c.query("set lc_monetary='C'")
1373
1374    def tearDown(self):
1375        self.c.close()
1376
1377    def testGetDecimalPoint(self):
1378        point = pg.get_decimal_point()
1379        # error if a parameter is passed
1380        self.assertRaises(TypeError, pg.get_decimal_point, point)
1381        self.assertIsInstance(point, str)
1382        self.assertEqual(point, '.')  # the default setting
1383        pg.set_decimal_point(',')
1384        try:
1385            r = pg.get_decimal_point()
1386        finally:
1387            pg.set_decimal_point(point)
1388        self.assertIsInstance(r, str)
1389        self.assertEqual(r, ',')
1390        pg.set_decimal_point("'")
1391        try:
1392            r = pg.get_decimal_point()
1393        finally:
1394            pg.set_decimal_point(point)
1395        self.assertIsInstance(r, str)
1396        self.assertEqual(r, "'")
1397        pg.set_decimal_point('')
1398        try:
1399            r = pg.get_decimal_point()
1400        finally:
1401            pg.set_decimal_point(point)
1402        self.assertIsNone(r)
1403        pg.set_decimal_point(None)
1404        try:
1405            r = pg.get_decimal_point()
1406        finally:
1407            pg.set_decimal_point(point)
1408        self.assertIsNone(r)
1409
1410    def testSetDecimalPoint(self):
1411        d = pg.Decimal
1412        point = pg.get_decimal_point()
1413        self.assertRaises(TypeError, pg.set_decimal_point)
1414        # error if decimal point is not a string
1415        self.assertRaises(TypeError, pg.set_decimal_point, 0)
1416        # error if more than one decimal point passed
1417        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
1418        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
1419        # error if decimal point is not a punctuation character
1420        self.assertRaises(TypeError, pg.set_decimal_point, '0')
1421        query = self.c.query
1422        # check that money values are interpreted as decimal values
1423        # only if decimal_point is set, and that the result is correct
1424        # only if it is set suitable for the current lc_monetary setting
1425        select_money = "select '34.25'::money"
1426        proper_money = d('34.25')
1427        bad_money = d('3425')
1428        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
1429        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
1430        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
1431        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
1432            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
1433        # first try with English localization (using the point)
1434        for lc in en_locales:
1435            try:
1436                query("set lc_monetary='%s'" % lc)
1437            except pg.ProgrammingError:
1438                pass
1439            else:
1440                break
1441        else:
1442            self.skipTest("cannot set English money locale")
1443        try:
1444            r = query(select_money)
1445        except pg.ProgrammingError:
1446            # this can happen if the currency signs cannot be
1447            # converted using the encoding of the test database
1448            self.skipTest("database does not support English money")
1449        pg.set_decimal_point(None)
1450        try:
1451            r = r.getresult()[0][0]
1452        finally:
1453            pg.set_decimal_point(point)
1454        self.assertIsInstance(r, str)
1455        self.assertIn(r, en_money)
1456        r = query(select_money)
1457        pg.set_decimal_point('')
1458        try:
1459            r = r.getresult()[0][0]
1460        finally:
1461            pg.set_decimal_point(point)
1462        self.assertIsInstance(r, str)
1463        self.assertIn(r, en_money)
1464        r = query(select_money)
1465        pg.set_decimal_point('.')
1466        try:
1467            r = r.getresult()[0][0]
1468        finally:
1469            pg.set_decimal_point(point)
1470        self.assertIsInstance(r, d)
1471        self.assertEqual(r, proper_money)
1472        r = query(select_money)
1473        pg.set_decimal_point(',')
1474        try:
1475            r = r.getresult()[0][0]
1476        finally:
1477            pg.set_decimal_point(point)
1478        self.assertIsInstance(r, d)
1479        self.assertEqual(r, bad_money)
1480        r = query(select_money)
1481        pg.set_decimal_point("'")
1482        try:
1483            r = r.getresult()[0][0]
1484        finally:
1485            pg.set_decimal_point(point)
1486        self.assertIsInstance(r, d)
1487        self.assertEqual(r, bad_money)
1488        # then try with German localization (using the comma)
1489        for lc in de_locales:
1490            try:
1491                query("set lc_monetary='%s'" % lc)
1492            except pg.ProgrammingError:
1493                pass
1494            else:
1495                break
1496        else:
1497            self.skipTest("cannot set German money locale")
1498        select_money = select_money.replace('.', ',')
1499        try:
1500            r = query(select_money)
1501        except pg.ProgrammingError:
1502            self.skipTest("database does not support English money")
1503        pg.set_decimal_point(None)
1504        try:
1505            r = r.getresult()[0][0]
1506        finally:
1507            pg.set_decimal_point(point)
1508        self.assertIsInstance(r, str)
1509        self.assertIn(r, de_money)
1510        r = query(select_money)
1511        pg.set_decimal_point('')
1512        try:
1513            r = r.getresult()[0][0]
1514        finally:
1515            pg.set_decimal_point(point)
1516        self.assertIsInstance(r, str)
1517        self.assertIn(r, de_money)
1518        r = query(select_money)
1519        pg.set_decimal_point(',')
1520        try:
1521            r = r.getresult()[0][0]
1522        finally:
1523            pg.set_decimal_point(point)
1524        self.assertIsInstance(r, d)
1525        self.assertEqual(r, proper_money)
1526        r = query(select_money)
1527        pg.set_decimal_point('.')
1528        try:
1529            r = r.getresult()[0][0]
1530        finally:
1531            pg.set_decimal_point(point)
1532        self.assertEqual(r, bad_money)
1533        r = query(select_money)
1534        pg.set_decimal_point("'")
1535        try:
1536            r = r.getresult()[0][0]
1537        finally:
1538            pg.set_decimal_point(point)
1539        self.assertEqual(r, bad_money)
1540
1541    def testGetDecimal(self):
1542        decimal_class = pg.get_decimal()
1543        # error if a parameter is passed
1544        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
1545        self.assertIs(decimal_class, pg.Decimal)  # the default setting
1546        pg.set_decimal(int)
1547        try:
1548            r = pg.get_decimal()
1549        finally:
1550            pg.set_decimal(decimal_class)
1551        self.assertIs(r, int)
1552        r = pg.get_decimal()
1553        self.assertIs(r, decimal_class)
1554
1555    def testSetDecimal(self):
1556        decimal_class = pg.get_decimal()
1557        # error if no parameter is passed
1558        self.assertRaises(TypeError, pg.set_decimal)
1559        query = self.c.query
1560        try:
1561            r = query("select 3425::numeric")
1562        except pg.ProgrammingError:
1563            self.skipTest('database does not support numeric')
1564        r = r.getresult()[0][0]
1565        self.assertIsInstance(r, decimal_class)
1566        self.assertEqual(r, decimal_class('3425'))
1567        r = query("select 3425::numeric")
1568        pg.set_decimal(int)
1569        try:
1570            r = r.getresult()[0][0]
1571        finally:
1572            pg.set_decimal(decimal_class)
1573        self.assertNotIsInstance(r, decimal_class)
1574        self.assertIsInstance(r, int)
1575        self.assertEqual(r, int(3425))
1576
1577    def testGetBool(self):
1578        use_bool = pg.get_bool()
1579        # error if a parameter is passed
1580        self.assertRaises(TypeError, pg.get_bool, use_bool)
1581        self.assertIsInstance(use_bool, bool)
1582        self.assertIs(use_bool, False)  # the default setting
1583        pg.set_bool(True)
1584        try:
1585            r = pg.get_bool()
1586        finally:
1587            pg.set_bool(use_bool)
1588        self.assertIsInstance(r, bool)
1589        self.assertIs(r, True)
1590        pg.set_bool(False)
1591        try:
1592            r = pg.get_bool()
1593        finally:
1594            pg.set_bool(use_bool)
1595        self.assertIsInstance(r, bool)
1596        self.assertIs(r, False)
1597        pg.set_bool(1)
1598        try:
1599            r = pg.get_bool()
1600        finally:
1601            pg.set_bool(use_bool)
1602        self.assertIsInstance(r, bool)
1603        self.assertIs(r, True)
1604        pg.set_bool(0)
1605        try:
1606            r = pg.get_bool()
1607        finally:
1608            pg.set_bool(use_bool)
1609        self.assertIsInstance(r, bool)
1610        self.assertIs(r, False)
1611
1612    def testSetBool(self):
1613        use_bool = pg.get_bool()
1614        # error if no parameter is passed
1615        self.assertRaises(TypeError, pg.set_bool)
1616        query = self.c.query
1617        try:
1618            r = query("select true::bool")
1619        except pg.ProgrammingError:
1620            self.skipTest('database does not support bool')
1621        r = r.getresult()[0][0]
1622        self.assertIsInstance(r, str)
1623        self.assertEqual(r, 't')
1624        r = query("select true::bool")
1625        pg.set_bool(True)
1626        try:
1627            r = r.getresult()[0][0]
1628        finally:
1629            pg.set_bool(use_bool)
1630        self.assertIsInstance(r, bool)
1631        self.assertIs(r, True)
1632        r = query("select true::bool")
1633        pg.set_bool(False)
1634        try:
1635            r = r.getresult()[0][0]
1636        finally:
1637            pg.set_bool(use_bool)
1638        self.assertIsInstance(r, str)
1639        self.assertIs(r, 't')
1640
1641    def testGetNamedresult(self):
1642        namedresult = pg.get_namedresult()
1643        # error if a parameter is passed
1644        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
1645        self.assertIs(namedresult, pg._namedresult)  # the default setting
1646
1647    def testSetNamedresult(self):
1648        namedresult = pg.get_namedresult()
1649        self.assertTrue(callable(namedresult))
1650
1651        query = self.c.query
1652
1653        r = query("select 1 as x, 2 as y").namedresult()[0]
1654        self.assertIsInstance(r, tuple)
1655        self.assertEqual(r, (1, 2))
1656        self.assertIsNot(type(r), tuple)
1657        self.assertEqual(r._fields, ('x', 'y'))
1658        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
1659        self.assertEqual(r.__class__.__name__, 'Row')
1660
1661        def listresult(q):
1662            return [list(row) for row in q.getresult()]
1663
1664        pg.set_namedresult(listresult)
1665        try:
1666            r = pg.get_namedresult()
1667            self.assertIs(r, listresult)
1668            r = query("select 1 as x, 2 as y").namedresult()[0]
1669            self.assertIsInstance(r, list)
1670            self.assertEqual(r, [1, 2])
1671            self.assertIsNot(type(r), tuple)
1672            self.assertFalse(hasattr(r, '_fields'))
1673            self.assertNotEqual(r.__class__.__name__, 'Row')
1674        finally:
1675            pg.set_namedresult(namedresult)
1676
1677        r = pg.get_namedresult()
1678        self.assertIs(r, namedresult)
1679
1680
class TestStandaloneEscapeFunctions(unittest.TestCase):
    """Test pg escape functions.

    The libpq interface memorizes some parameters of the last opened
    connection that influence the result of these functions.  Therefore
    we need to open a connection with fixed parameters prior to testing
    in order to ensure that the tests always run under the same conditions.
    That's why these tests are included in this test module.
    """

    # flipped to True by setUpClass; every test asserts it before running
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Open a connection and fix the settings that affect escaping."""
        # NOTE(review): the connection is not closed explicitly here --
        # confirm whether that is intentional before adding a close().
        query = connect().query
        query('set client_encoding=sql_ascii')
        query('set standard_conforming_strings=off')
        query('set bytea_output=escape')
        cls.cls_set_up = True

    def testEscapeString(self):
        """Check escape_string() with bytes and unicode input."""
        self.assertTrue(self.cls_set_up)
        f = pg.escape_string
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # U+00E4 is written as an escape so that the character cannot be
        # garbled by a wrong source encoding
        r = f(u"das is' k\xe4se".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, u"das is'' k\xe4se".encode('utf-8'))
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        # the backslash is expected to be doubled as well as the quote
        r = f(r"It's bad to have a \ inside.")
        self.assertEqual(r, r"It''s bad to have a \\ inside.")

    def testEscapeBytea(self):
        """Check escape_bytea() with bytes and unicode input."""
        self.assertTrue(self.cls_set_up)
        f = pg.escape_bytea
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # U+00E4 encodes to the UTF-8 bytes C3 A4, i.e. octal 303 244,
        # which is exactly what the expected escape sequence spells out
        r = f(u"das is' k\xe4se".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        # non-printable bytes are expected as backslash-escaped octal values
        r = f(b'O\x00ps\xff!')
        self.assertEqual(r, b'O\\\\000ps\\\\377!')
1736
1737
# Run the tests of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
Note: See TracBrowser for help on using the repository browser.