source: trunk/tests/test_classic_connection.py @ 814

Last change on this file since 814 was 814, checked in by cito, 4 years ago

Add typecasting of dates, times, timestamps, intervals

So far, PyGreSQL has returned these types only as strings (in various
formats depending on the DateStyle setting) and left it to the user
to parse and interpret the strings. These types are now properly cast
into the corresponding datetime types of Python, and this works with
any setting of DateStyle, even if you change DateStyle in the middle
of a database session.

To implement this, a fast method for getting the datestyle (cached and
without roundtrip to the database) has been added. Also, the typecast
mechanism has been extended so that typecast functions can optionally
also take the connection as argument.

The date and time typecast functions have been implemented in Python
using the new typecast registry and added to both pg and pgdb. Some
duplication of code in the two modules was unavoidable, since we don't
want the modules to be dependent on each other or install additional
helper modules. One day we might want to change this, put everything
in one package and factor out some of the functionality.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 65.9 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17import threading
18import time
19import os
20
21import pg  # the module under test
22
23from decimal import Decimal
24
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# These tests should be run with various PostgreSQL versions and databases
# created with different encodings and locales.  Particularly, make sure the
# tests are running against databases created with both SQL_ASCII and UTF8.

# Default connection parameters, passed to pg.connect() by connect().
dbname = 'unittest'
dbhost = None  # NOTE(review): presumably "use the libpq default host"
dbport = 5432

# Any of the names above may be overridden by a LOCAL_PyGreSQL module.
# A package-relative import is tried first, then a top-level one; the
# ValueError is presumably caught for Python 2, which raises it for a
# relative import outside of a package.  If no override module can be
# imported at all, the defaults above are kept.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass
41
# Python 3 has neither a separate long nor a separate unicode builtin;
# map both names to their Python 3 counterparts in that case.
try:
    long
    unicode
except NameError:  # Python >= 3.0
    long, unicode = int, str

# True when the native str type is a text (unicode) type.
unicode_strings = str is not bytes

windows = os.name == 'nt'

# A known bug in libpq under Windows can crash the interface when
# PQhost() is called, so the tests must not ask for the host there.
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
60
61
def connect():
    """Open a classic pg connection to the configured test database.

    The session is configured so that only messages of level WARNING or
    more severe are sent to the client.
    """
    con = pg.connect(dbname, dbhost, dbport)
    con.query("set client_min_messages=warning")
    return con
67
68
class TestCanConnect(unittest.TestCase):
    """Check that a basic connection to PostgreSQL can be established."""

    def testCanConnect(self):
        """Open a connection to the test database and close it again."""
        try:
            con = connect()
        except pg.Error as error:
            self.fail('Cannot connect to database %s:\n%s' % (dbname, error))
        else:
            try:
                con.close()
            except pg.Error:
                self.fail('Cannot close the database connection')
81
82
class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods."""

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        try:
            self.connection.close()
        except pg.InternalError:
            pass  # closing an already closed connection raises this

    def is_method(self, attribute):
        """Check if given attribute on the connection is a method.

        When do_not_ask_for_host is set, the 'host' attribute is never
        accessed and is reported as a non-method.
        """
        if do_not_ask_for_host and attribute == 'host':
            return False
        return callable(getattr(self.connection, attribute))

    def testClassName(self):
        self.assertEqual(self.connection.__class__.__name__, 'Connection')

    def testModuleName(self):
        self.assertEqual(self.connection.__class__.__module__, 'pg')

    def testStr(self):
        r = str(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testRepr(self):
        r = repr(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testAllConnectAttributes(self):
        attributes = '''db error host options port
            protocol_version server_version status user'''.split()
        connection_attributes = [a for a in dir(self.connection)
            if not a.startswith('__') and not self.is_method(a)]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        methods = '''cancel close date_format endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_cast_hook get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter putline query reset
            set_cast_hook set_notice_receiver source transaction'''.split()
        connection_methods = [a for a in dir(self.connection)
            if not a.startswith('__') and self.is_method(a)]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, dbhost or def_host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        # Since PostgreSQL 10 the version number has six digits
        # (e.g. 100000 for 10.0), so the upper bound must allow for that.
        self.assertTrue(70400 <= server_version < 1000000)

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testMethodEndcopy(self):
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # give tearDown an open connection to close again
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query
        # check that client encoding gets reset
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        self.connection.query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.ProgrammingError as error:
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        while 1:  # make sure the query is really running
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)

    def testMethodTransaction(self):
        transaction = self.connection.transaction
        self.assertRaises(TypeError, transaction, None)
        self.assertEqual(transaction(), pg.TRANS_IDLE)
        self.connection.query('begin')
        self.assertEqual(transaction(), pg.TRANS_INTRANS)
        self.connection.query('rollback')
        self.assertEqual(transaction(), pg.TRANS_IDLE)

    def testMethodParameter(self):
        """Check that parameter() agrees with the corresponding "show"."""
        parameter = self.connection.parameter
        query = self.connection.query
        self.assertRaises(TypeError, parameter)
        r = parameter('this server setting does not exist')
        self.assertIsNone(r)
        for name in ('server_version', 'server_encoding', 'client_encoding'):
            s = query('show %s' % name).getresult()[0][0]
            self.assertIsNotNone(s)
            r = parameter(name)
            self.assertIsNotNone(r)
            # Compare case-insensitively on both sides; uppercasing only one
            # side breaks for values with letters, e.g. a server version
            # string carrying a distribution suffix.
            self.assertEqual(r.upper(), s.upper())
289
290
class TestSimpleQueries(unittest.TestCase):
    """Test simple queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # run registered cleanups (e.g. dropping test tables) while the
        # connection they use is still open
        self.doCleanups()
        self.c.close()

    def testClassName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__name__, 'Query')

    def testModuleName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__module__, 'pg')

    def testStr(self):
        q = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        r = self.c.query(q)
        self.assertEqual(str(r),
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)')

    def testRepr(self):
        r = repr(self.c.query("select 1"))
        self.assertTrue(r.startswith('<pg.Query object'), r)

    def testSelect0(self):
        q = "select 0"
        self.c.query(q)

    def testSelect0Semicolon(self):
        q = "select 0;"
        self.c.query(q)

    def testSelectDotSemicolon(self):
        q = "select .;"
        self.assertRaises(pg.ProgrammingError, self.c.query, q)

    def testGetresult(self):
        q = "select 0"
        result = [(0,)]
        r = self.c.query(q).getresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, tuple)
        self.assertIsInstance(v[0], int)
        self.assertEqual(r, result)

    def testGetresultLong(self):
        q = "select 9876543210"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testGetresultDecimal(self):
        q = "select 98765432109876543210"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testGetresultString(self):
        result = 'Hello, world!'
        q = "select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresult(self):
        q = "select 0 as alias0"
        result = [{'alias0': 0}]
        r = self.c.query(q).dictresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, dict)
        self.assertIsInstance(v['alias0'], int)
        self.assertEqual(r, result)

    def testDictresultLong(self):
        q = "select 9876543210 as longjohnsilver"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testDictresultDecimal(self):
        q = "select 98765432109876543210 as longjohnsilver"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testDictresultString(self):
        result = 'Hello, world!'
        q = "select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testNamedresult(self):
        q = "select 0 as alias0"
        result = [(0,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('alias0',))
        self.assertEqual(v.alias0, 0)

    def testGet3Cols(self):
        q = "select 1,2,3"
        result = [(1, 2, 3)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [dict(a=1, b=2, c=3)]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [(1, 2, 3)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('a', 'b', 'c'))
        self.assertEqual(v.b, 2)

    def testGet3Rows(self):
        q = "select 3 union select 1 union select 2 order by 1"
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        for v in r:
            self.assertEqual(v._fields, ('alias3',))

    def testDictresultNames(self):
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [{'mixedcasealias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        result = [{'MixedCaseAlias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testNamedresultNames(self):
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [('MixedCase',)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('mixedcasealias',))
        self.assertEqual(v.mixedcasealias, 'MixedCase')
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('MixedCaseAlias',))
        self.assertEqual(v.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = 100
        num_rows = 100
        q = "select " + ','.join(map(str, range(num_cols)))
        q = ' union all '.join((q,) * num_rows)
        r = self.c.query(q).getresult()
        result = [tuple(range(num_cols))] * num_rows
        self.assertEqual(r, result)

    def testListfields(self):
        q = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        r = self.c.query(q).listfields()
        result = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(r, result)

    def testFieldname(self):
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldname(2)
        self.assertEqual(r, 'x')
        r = self.c.query(q).fieldname(3)
        self.assertEqual(r, 'y')

    def testFieldnum(self):
        q = "select 1 as x"
        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        r = self.c.query(q).fieldnum('y')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3)

    def testNtuples(self):
        q = "select 1 where false"
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 6)

    def testQuery(self):
        """Check the return values of query() for non-select statements.

        query() returns None for DDL, the oid (an int) for a single-row
        insert into a table with oids, and the number of affected rows
        as a string otherwise.
        """
        # "create table ... with oids" is rejected since PostgreSQL 12
        if self.c.server_version >= 120000:
            self.skipTest("tables with oids are not supported any more")
        query = self.c.query
        query("drop table if exists test_table")
        self.addCleanup(query, "drop table test_table")
        q = "create table test_table (n integer) with oids"
        r = query(q)
        self.assertIsNone(r)
        q = "insert into test_table values (1)"
        r = query(q)
        self.assertIsInstance(r, int)
        q = "insert into test_table select 2"
        r = query(q)
        self.assertIsInstance(r, int)
        oid = r
        q = "select oid from test_table where n=2"
        r = query(q).getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertIsInstance(r, int)
        self.assertEqual(r, oid)
        q = "insert into test_table select 3 union select 4 union select 5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '3')
        q = "update test_table set n=4 where n<5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '4')
        q = "delete from test_table"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '5')
572
573
class TestUnicodeQueries(unittest.TestCase):
    """Test unicode strings as queries via a basic pg connection.

    Each test passes the query first as a unicode string and then as
    bytes in the client encoding, and expects the same native str result.
    """

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')

    def tearDown(self):
        self.c.close()

    def testGetresulAscii(self):
        result = u'Hello, world!'
        q = u"select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresulAscii(self):
        result = u'Hello, world!'
        q = u"select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('utf8')
        # pass the query as unicode
        try:
            v = self.c.query(q).getresult()[0][0]
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        # pass the query as bytes
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('utf8')
        try:
            v = self.c.query(q).dictresult()[0]['greeting']
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    # This getresult variant used to share the name testDictresultLatin1
    # with the dictresult variant below and was silently shadowed by it.
    def testGetresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except pg.ProgrammingError:
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except pg.ProgrammingError:
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin9")
        result = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin9")
        result = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
        q = u"select '%s' as menu" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
733
734
735class TestParamQueries(unittest.TestCase):
736    """Test queries with parameters via a basic pg connection."""
737
738    def setUp(self):
739        self.c = connect()
740        self.c.query('set client_encoding=utf8')
741
742    def tearDown(self):
743        self.c.close()
744
745    def testQueryWithNoneParam(self):
746        self.assertRaises(TypeError, self.c.query, "select $1", None)
747        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
748        self.assertEqual(self.c.query("select $1::integer", (None,)
749            ).getresult(), [(None,)])
750        self.assertEqual(self.c.query("select $1::text", [None]
751            ).getresult(), [(None,)])
752        self.assertEqual(self.c.query("select $1::text", [[None]]
753            ).getresult(), [(None,)])
754
755    def testQueryWithBoolParams(self, bool_enabled=None):
756        query = self.c.query
757        if bool_enabled is not None:
758            bool_enabled_default = pg.get_bool()
759            pg.set_bool(bool_enabled)
760        try:
761            bool_on = bool_enabled or bool_enabled is None
762            v_false, v_true = (False, True) if bool_on else 'ft'
763            r_false, r_true = [(v_false,)], [(v_true,)]
764            self.assertEqual(query("select false").getresult(), r_false)
765            self.assertEqual(query("select true").getresult(), r_true)
766            q = "select $1::bool"
767            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
768            self.assertEqual(query(q, ('f',)).getresult(), r_false)
769            self.assertEqual(query(q, ('t',)).getresult(), r_true)
770            self.assertEqual(query(q, ('false',)).getresult(), r_false)
771            self.assertEqual(query(q, ('true',)).getresult(), r_true)
772            self.assertEqual(query(q, ('n',)).getresult(), r_false)
773            self.assertEqual(query(q, ('y',)).getresult(), r_true)
774            self.assertEqual(query(q, (0,)).getresult(), r_false)
775            self.assertEqual(query(q, (1,)).getresult(), r_true)
776            self.assertEqual(query(q, (False,)).getresult(), r_false)
777            self.assertEqual(query(q, (True,)).getresult(), r_true)
778        finally:
779            if bool_enabled is not None:
780                pg.set_bool(bool_enabled_default)
781
782    def testQueryWithBoolParamsNotDefault(self):
783        self.testQueryWithBoolParams(bool_enabled=not pg.get_bool())
784
785    def testQueryWithIntParams(self):
786        query = self.c.query
787        self.assertEqual(query("select 1+1").getresult(), [(2,)])
788        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
789        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
790        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
791        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
792        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
793            [(Decimal('2'),)])
794        self.assertEqual(query("select 1, $1::integer", (2,)
795            ).getresult(), [(1, 2)])
796        self.assertEqual(query("select 1 union select $1::integer", (2,)
797            ).getresult(), [(1,), (2,)])
798        self.assertEqual(query("select $1::integer+$2", (1, 2)
799            ).getresult(), [(3,)])
800        self.assertEqual(query("select $1::integer+$2", [1, 2]
801            ).getresult(), [(3,)])
802        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
803            ).getresult(), [(15,)])
804
805    def testQueryWithStrParams(self):
806        query = self.c.query
807        self.assertEqual(query("select $1||', world!'", ('Hello',)
808            ).getresult(), [('Hello, world!',)])
809        self.assertEqual(query("select $1||', world!'", ['Hello']
810            ).getresult(), [('Hello, world!',)])
811        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
812            ).getresult(), [('Hello, world!',)])
813        self.assertEqual(query("select $1::text", ('Hello, world!',)
814            ).getresult(), [('Hello, world!',)])
815        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
816            ).getresult(), [('Hello', 'world')])
817        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
818            ).getresult(), [('Hello', 'world')])
819        self.assertEqual(query("select $1::text union select $2::text",
820            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
821        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
822            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
823
824    def testQueryWithUnicodeParams(self):
825        query = self.c.query
826        try:
827            query('set client_encoding=utf8')
828            query("select 'wörld'").getresult()[0][0] == 'wörld'
829        except pg.ProgrammingError:
830            self.skipTest("database does not support utf8")
831        self.assertEqual(query("select $1||', '||$2||'!'",
832            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
833
834    def testQueryWithUnicodeParamsLatin1(self):
835        query = self.c.query
836        try:
837            query('set client_encoding=latin1')
838            query("select 'wörld'").getresult()[0][0] == 'wörld'
839        except pg.ProgrammingError:
840            self.skipTest("database does not support latin1")
841        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
842        if unicode_strings:
843            self.assertEqual(r, [('Hello, wörld!',)])
844        else:
845            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
846        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
847            ('Hello', u'ЌОр'))
848        query('set client_encoding=iso_8859_1')
849        r = query("select $1||', '||$2||'!'",
850            ('Hello', u'wörld')).getresult()
851        if unicode_strings:
852            self.assertEqual(r, [('Hello, wörld!',)])
853        else:
854            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
855        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
856            ('Hello', u'ЌОр'))
857        query('set client_encoding=sql_ascii')
858        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
859            ('Hello', u'wörld'))
860
861    def testQueryWithUnicodeParamsCyrillic(self):
862        query = self.c.query
863        try:
864            query('set client_encoding=iso_8859_5')
865            query("select 'ЌОр'").getresult()[0][0] == 'ЌОр'
866        except pg.ProgrammingError:
867            self.skipTest("database does not support cyrillic")
868        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
869            ('Hello', u'wörld'))
870        r = query("select $1||', '||$2||'!'",
871            ('Hello', u'ЌОр')).getresult()
872        if unicode_strings:
873            self.assertEqual(r, [('Hello, ЌОр!',)])
874        else:
875            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
876        query('set client_encoding=sql_ascii')
877        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
878            ('Hello', u'ЌОр!'))
879
880    def testQueryWithMixedParams(self):
881        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
882            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
883        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
884            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
885
886    def testQueryWithDuplicateParams(self):
887        self.assertRaises(pg.ProgrammingError,
888            self.c.query, "select $1+$1", (1,))
889        self.assertRaises(pg.ProgrammingError,
890            self.c.query, "select $1+$1", (1, 2))
891
892    def testQueryWithZeroParams(self):
893        self.assertEqual(self.c.query("select 1+1", []
894            ).getresult(), [(2,)])
895
896    def testQueryWithGarbage(self):
897        garbage = r"'\{}+()-#[]oo324"
898        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
899            ).dictresult(), [{'garbage': garbage}])
900
901
class TestQueryResultTypes(unittest.TestCase):
    """Test proper result types via a basic pg connection."""

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        self.c.close()

    def assert_proper_cast(self, value, pgtype, pytype):
        """Check that value cast to pgtype is returned as pytype.

        The value is passed as a query parameter and cast to pgtype, first
        as a scalar and then wrapped into a one-element array of pgtype.
        """
        q = 'select $1::%s' % (pgtype,)
        r = self.c.query(q, (value,)).getresult()[0][0]
        self.assertIsInstance(r, pytype)
        if isinstance(value, str):
            # quote elements that would otherwise break the array literal
            if not value or ' ' in value or '{' in value:
                value = '"%s"' % value
        value = '{%s}' % value
        r = self.c.query(q + '[]', (value,)).getresult()[0][0]
        if pgtype.startswith(('date', 'time', 'interval')):
            # arrays of these are cast by the DB wrapper only, so here
            # the array literal is expected back unchanged
            self.assertEqual(r, value)
        else:
            self.assertIsInstance(r, list)
            self.assertEqual(len(r), 1)
            self.assertIsInstance(r[0], pytype)

    def testInt(self):
        self.assert_proper_cast(0, 'int', int)
        self.assert_proper_cast(0, 'smallint', int)
        self.assert_proper_cast(0, 'oid', int)
        self.assert_proper_cast(0, 'cid', int)
        self.assert_proper_cast(0, 'xid', int)

    def testLong(self):
        self.assert_proper_cast(0, 'bigint', long)

    def testFloat(self):
        self.assert_proper_cast(0, 'float', float)
        self.assert_proper_cast(0, 'real', float)
        self.assert_proper_cast(0, 'double', float)
        self.assert_proper_cast(0, 'double precision', float)
        self.assert_proper_cast('infinity', 'float', float)

    # This was previously a second "testFloat", which silently replaced
    # the float test above so that it never ran.
    def testNumeric(self):
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal(0), 'numeric', decimal)
        self.assert_proper_cast(decimal(0), 'decimal', decimal)

    def testMoney(self):
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal('0'), 'money', decimal)

    def testBool(self):
        bool_type = bool if pg.get_bool() else str
        self.assert_proper_cast('f', 'bool', bool_type)

    def testDate(self):
        self.assert_proper_cast('1956-01-31', 'date', str)
        self.assert_proper_cast('10:20:30', 'interval', str)
        self.assert_proper_cast('08:42:15', 'time', str)
        self.assert_proper_cast('08:42:15+00', 'timetz', str)
        self.assert_proper_cast('1956-01-31 08:42:15', 'timestamp', str)
        self.assert_proper_cast('1956-01-31 08:42:15+00', 'timestamptz', str)

    def testText(self):
        self.assert_proper_cast('', 'text', str)
        self.assert_proper_cast('', 'char', str)
        self.assert_proper_cast('', 'bpchar', str)
        self.assert_proper_cast('', 'varchar', str)

    def testBytea(self):
        self.assert_proper_cast('', 'bytea', bytes)

    def testJson(self):
        self.assert_proper_cast('{}', 'json', dict)
979
980
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test ("
                "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
                "d numeric, f4 real, f8 double precision, m money,"
                "c char(1), v4 varchar(4), c4 char(4), t text)")
            # Check whether the test database uses SQL_ASCII - this means
            # that it does not consider encoding when calculating lengths.
            c.query("set client_encoding=utf8")
            cls.has_encoding = c.query(
                "select length('À') - length('a')").getresult()[0][0] == 0
        finally:
            c.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def db_len(cls, s, encoding):
        """Return the length of s as the database computes it.

        With an encoding-aware database this is the number of characters;
        otherwise it is the number of bytes in the given encoding.
        """
        if cls.has_encoding:
            s = s if isinstance(s, unicode) else s.decode(encoding)
        else:
            s = s.encode(encoding) if isinstance(s, unicode) else s
        return len(s)

    def get_back(self, encoding='utf-8'):
        """Fetch the test table rows, type-check and normalize them.

        Each column's Python type is checked.  Numeric values are converted
        to float, money values to the string of their float value, and
        char(4) values have their padding stripped, so that the rows can be
        compared with the data passed to inserttable.
        """
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], bool)
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision
                self.assertIsInstance(row[8], float)
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(self.db_len(row[10], encoding), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(self.db_len(row[11], encoding), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(self.db_len(row[12], encoding), 4)
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            row = tuple(row)
            data.append(row)
        return data

    def testInserttable1Row(self):
        data = self.data[2:3]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        data = self.data
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMultipleRows(self):
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        data = [(None,) * 14] * 100
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableByteValues(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"kÀse сыр pont-l'évêque")
        row_bytes = tuple(s.encode('utf-8')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_bytes] * 2
        self.c.inserttable('test', data)
        if unicode_strings:
            data = [row_unicode] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeUtf8(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"kÀse сыр pont-l'évêque")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            row_bytes = tuple(s.encode('utf-8')
                if isinstance(s, unicode) else s for s in row_unicode)
            data = [row_bytes] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeLatin1(self):
        try:
            self.c.query("set client_encoding=latin1")
            self.c.query("select '¥'")
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode € sign with latin1 encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
        # Replace the euro sign with the yen sign, which latin1 can encode.
        # It must be a single character, since it also replaces the value
        # of the char(1) column when the database is encoding-aware.
        row_unicode = tuple(s.replace(u'€', u'¥')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            row_bytes = tuple(s.encode('latin1')
                if isinstance(s, unicode) else s for s in row_unicode)
            data = [row_bytes] * 2
        self.assertEqual(self.get_back('latin1'), data)

    def testInserttableUnicodeLatin9(self):
        try:
            self.c.query("set client_encoding=latin9")
            self.c.query("select '€'")
        except pg.ProgrammingError:
            # skipTest raises, so nothing below runs in this case
            self.skipTest("database does not support latin9")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            row_bytes = tuple(s.encode('latin9')
                if isinstance(s, unicode) else s for s in row_unicode)
            data = [row_bytes] * 2
        self.assertEqual(self.get_back('latin9'), data)

    def testInserttableNoEncoding(self):
        self.c.query("set client_encoding=sql_ascii")
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0',
            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode non-ascii unicode without a specific encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
1213
1214
class TestDirectSocketAccess(unittest.TestCase):
    """Test the COPY command driven through putline, getline and endcopy."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        conn = connect()
        for sql in ("drop table if exists test cascade",
                    "create table test (i int, v varchar(16))"):
            conn.query(sql)
        conn.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        conn = connect()
        conn.query("drop table test cascade")
        conn.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    def _end_copy_quietly(self):
        # Finish the copy operation, ignoring an IOError from endcopy.
        try:
            self.c.endcopy()
        except IOError:
            pass

    def testPutline(self):
        rows = list(enumerate("apple pear plum cherry banana".split()))
        self.c.query("copy test from stdin")
        try:
            for number, fruit in rows:
                self.c.putline("%d\t%s\n" % (number, fruit))
            self.c.putline("\\.\n")
        finally:
            self.c.endcopy()
        self.assertEqual(self.c.query("select * from test").getresult(), rows)

    def testPutlineBytesAndUnicode(self):
        self.c.query("copy test from stdin")
        try:
            for line in (u"47\tkÀse\n".encode('utf8'),
                         "35\twÃŒrstel\n",
                         b"\\.\n"):
                self.c.putline(line)
        finally:
            self.c.endcopy()
        rows = self.c.query("select * from test").getresult()
        self.assertEqual(rows, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        rows = list(enumerate("apple banana pear plum strawberry".split()))
        self.c.inserttable('test', rows)
        self.c.query("copy test to stdout")
        try:
            # one line per row, then the end marker, then nothing
            for row in rows:
                self.assertEqual(self.c.getline(), '%d\t%s' % row)
            self.assertEqual(self.c.getline(), '\\.')
            self.assertIsNone(self.c.getline())
        finally:
            self._end_copy_quietly()

    def testGetlineBytesAndUnicode(self):
        self.c.inserttable('test',
            [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')])
        self.c.query("copy test to stdout")
        try:
            for expected in ('54\tkÀse', '73\twÃŒrstel'):
                line = self.c.getline()
                self.assertIsInstance(line, str)
                self.assertEqual(line, expected)
            self.assertEqual(self.c.getline(), '\\.')
            self.assertIsNone(self.c.getline())
        finally:
            self._end_copy_quietly()

    def testParameterChecks(self):
        wrong_calls = ((self.c.putline, ()),
                       (self.c.getline, ('invalid',)),
                       (self.c.endcopy, ('invalid',)))
        for method, args in wrong_calls:
            self.assertRaises(TypeError, method, *args)
1317
1318
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # Run the registered cleanups explicitly while the connection is
        # still open, because they issue queries on it.
        self.doCleanups()
        self.c.close()

    def _check_notify(self, r, payload):
        # Check that r is a (str, int, str) 3-tuple for test_notify
        # carrying the given payload.
        self.assertIsInstance(r, tuple)
        self.assertEqual(len(r), 3)
        self.assertIsInstance(r[0], str)
        self.assertIsInstance(r[1], int)
        self.assertIsInstance(r[2], str)
        self.assertEqual(r[0], 'test_notify')
        self.assertEqual(r[2], payload)

    def testGetNotify(self):
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            self._check_notify(getnotify(), '')
            self.assertIsNone(getnotify())
            query("notify test_notify, 'test_payload'")
            self._check_notify(getnotify(), 'test_payload')
            self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 'invalid')
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))
        self.assertIsNone(self.c.set_notice_receiver(None))

    def testSetAndGetNoticeReceiver(self):
        def r(notice):
            return None

        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)
        self.assertIsNone(self.c.set_notice_receiver(None))
        self.assertIsNone(self.c.get_notice_receiver())

    def testNoticeReceiver(self):
        self.c.query('''create function bilbo_notice() returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        # Register the drop only once the function exists.  Otherwise a
        # failed create would also produce a spurious error from the drop.
        self.addCleanup(self.c.query, 'drop function bilbo_notice();')
        received = {}

        def notice_receiver(notice):
            # collect all public attributes of the notice; a German server
            # message prefix is normalized to the English one
            for attr in dir(notice):
                if attr.startswith('__'):
                    continue
                value = getattr(notice, attr)
                if isinstance(value, str):
                    value = value.replace('WARNUNG', 'WARNING')
                received[attr] = value

        self.c.set_notice_receiver(notice_receiver)
        self.c.query('select bilbo_notice()')
        self.assertEqual(received, dict(
            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
            severity='WARNING', primary='Bilbo was here!',
            detail=None, hint=None))
1399
1400
1401class TestConfigFunctions(unittest.TestCase):
1402    """Test the functions for changing default settings.
1403
1404    To test the effect of most of these functions, we need a database
1405    connection.  That's why they are covered in this test module.
1406    """
1407
1408    def setUp(self):
1409        self.c = connect()
1410        self.c.query("set client_encoding=utf8")
1411        self.c.query('set bytea_output=hex')
1412        self.c.query("set lc_monetary='C'")
1413
1414    def tearDown(self):
1415        self.c.close()
1416
1417    def testGetDecimalPoint(self):
1418        point = pg.get_decimal_point()
1419        # error if a parameter is passed
1420        self.assertRaises(TypeError, pg.get_decimal_point, point)
1421        self.assertIsInstance(point, str)
1422        self.assertEqual(point, '.')  # the default setting
1423        pg.set_decimal_point(',')
1424        try:
1425            r = pg.get_decimal_point()
1426        finally:
1427            pg.set_decimal_point(point)
1428        self.assertIsInstance(r, str)
1429        self.assertEqual(r, ',')
1430        pg.set_decimal_point("'")
1431        try:
1432            r = pg.get_decimal_point()
1433        finally:
1434            pg.set_decimal_point(point)
1435        self.assertIsInstance(r, str)
1436        self.assertEqual(r, "'")
1437        pg.set_decimal_point('')
1438        try:
1439            r = pg.get_decimal_point()
1440        finally:
1441            pg.set_decimal_point(point)
1442        self.assertIsNone(r)
1443        pg.set_decimal_point(None)
1444        try:
1445            r = pg.get_decimal_point()
1446        finally:
1447            pg.set_decimal_point(point)
1448        self.assertIsNone(r)
1449
1450    def testSetDecimalPoint(self):
1451        d = pg.Decimal
1452        point = pg.get_decimal_point()
1453        self.assertRaises(TypeError, pg.set_decimal_point)
1454        # error if decimal point is not a string
1455        self.assertRaises(TypeError, pg.set_decimal_point, 0)
1456        # error if more than one decimal point passed
1457        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
1458        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
1459        # error if decimal point is not a punctuation character
1460        self.assertRaises(TypeError, pg.set_decimal_point, '0')
1461        query = self.c.query
1462        # check that money values are interpreted as decimal values
1463        # only if decimal_point is set, and that the result is correct
1464        # only if it is set suitable for the current lc_monetary setting
1465        select_money = "select '34.25'::money"
1466        proper_money = d('34.25')
1467        bad_money = d('3425')
1468        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
1469        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
1470        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
1471        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
1472            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
1473        # first try with English localization (using the point)
1474        for lc in en_locales:
1475            try:
1476                query("set lc_monetary='%s'" % lc)
1477            except pg.ProgrammingError:
1478                pass
1479            else:
1480                break
1481        else:
1482            self.skipTest("cannot set English money locale")
1483        try:
1484            r = query(select_money)
1485        except pg.ProgrammingError:
1486            # this can happen if the currency signs cannot be
1487            # converted using the encoding of the test database
1488            self.skipTest("database does not support English money")
1489        pg.set_decimal_point(None)
1490        try:
1491            r = r.getresult()[0][0]
1492        finally:
1493            pg.set_decimal_point(point)
1494        self.assertIsInstance(r, str)
1495        self.assertIn(r, en_money)
1496        r = query(select_money)
1497        pg.set_decimal_point('')
1498        try:
1499            r = r.getresult()[0][0]
1500        finally:
1501            pg.set_decimal_point(point)
1502        self.assertIsInstance(r, str)
1503        self.assertIn(r, en_money)
1504        r = query(select_money)
1505        pg.set_decimal_point('.')
1506        try:
1507            r = r.getresult()[0][0]
1508        finally:
1509            pg.set_decimal_point(point)
1510        self.assertIsInstance(r, d)
1511        self.assertEqual(r, proper_money)
1512        r = query(select_money)
1513        pg.set_decimal_point(',')
1514        try:
1515            r = r.getresult()[0][0]
1516        finally:
1517            pg.set_decimal_point(point)
1518        self.assertIsInstance(r, d)
1519        self.assertEqual(r, bad_money)
1520        r = query(select_money)
1521        pg.set_decimal_point("'")
1522        try:
1523            r = r.getresult()[0][0]
1524        finally:
1525            pg.set_decimal_point(point)
1526        self.assertIsInstance(r, d)
1527        self.assertEqual(r, bad_money)
1528        # then try with German localization (using the comma)
1529        for lc in de_locales:
1530            try:
1531                query("set lc_monetary='%s'" % lc)
1532            except pg.ProgrammingError:
1533                pass
1534            else:
1535                break
1536        else:
1537            self.skipTest("cannot set German money locale")
1538        select_money = select_money.replace('.', ',')
1539        try:
1540            r = query(select_money)
1541        except pg.ProgrammingError:
1542            self.skipTest("database does not support English money")
1543        pg.set_decimal_point(None)
1544        try:
1545            r = r.getresult()[0][0]
1546        finally:
1547            pg.set_decimal_point(point)
1548        self.assertIsInstance(r, str)
1549        self.assertIn(r, de_money)
1550        r = query(select_money)
1551        pg.set_decimal_point('')
1552        try:
1553            r = r.getresult()[0][0]
1554        finally:
1555            pg.set_decimal_point(point)
1556        self.assertIsInstance(r, str)
1557        self.assertIn(r, de_money)
1558        r = query(select_money)
1559        pg.set_decimal_point(',')
1560        try:
1561            r = r.getresult()[0][0]
1562        finally:
1563            pg.set_decimal_point(point)
1564        self.assertIsInstance(r, d)
1565        self.assertEqual(r, proper_money)
1566        r = query(select_money)
1567        pg.set_decimal_point('.')
1568        try:
1569            r = r.getresult()[0][0]
1570        finally:
1571            pg.set_decimal_point(point)
1572        self.assertEqual(r, bad_money)
1573        r = query(select_money)
1574        pg.set_decimal_point("'")
1575        try:
1576            r = r.getresult()[0][0]
1577        finally:
1578            pg.set_decimal_point(point)
1579        self.assertEqual(r, bad_money)
1580
1581    def testGetDecimal(self):
1582        decimal_class = pg.get_decimal()
1583        # error if a parameter is passed
1584        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
1585        self.assertIs(decimal_class, pg.Decimal)  # the default setting
1586        pg.set_decimal(int)
1587        try:
1588            r = pg.get_decimal()
1589        finally:
1590            pg.set_decimal(decimal_class)
1591        self.assertIs(r, int)
1592        r = pg.get_decimal()
1593        self.assertIs(r, decimal_class)
1594
1595    def testSetDecimal(self):
1596        decimal_class = pg.get_decimal()
1597        # error if no parameter is passed
1598        self.assertRaises(TypeError, pg.set_decimal)
1599        query = self.c.query
1600        try:
1601            r = query("select 3425::numeric")
1602        except pg.ProgrammingError:
1603            self.skipTest('database does not support numeric')
1604        r = r.getresult()[0][0]
1605        self.assertIsInstance(r, decimal_class)
1606        self.assertEqual(r, decimal_class('3425'))
1607        r = query("select 3425::numeric")
1608        pg.set_decimal(int)
1609        try:
1610            r = r.getresult()[0][0]
1611        finally:
1612            pg.set_decimal(decimal_class)
1613        self.assertNotIsInstance(r, decimal_class)
1614        self.assertIsInstance(r, int)
1615        self.assertEqual(r, int(3425))
1616
1617    def testGetBool(self):
1618        use_bool = pg.get_bool()
1619        # error if a parameter is passed
1620        self.assertRaises(TypeError, pg.get_bool, use_bool)
1621        self.assertIsInstance(use_bool, bool)
1622        self.assertIs(use_bool, True)  # the default setting
1623        pg.set_bool(False)
1624        try:
1625            r = pg.get_bool()
1626        finally:
1627            pg.set_bool(use_bool)
1628        self.assertIsInstance(r, bool)
1629        self.assertIs(r, False)
1630        pg.set_bool(True)
1631        try:
1632            r = pg.get_bool()
1633        finally:
1634            pg.set_bool(use_bool)
1635        self.assertIsInstance(r, bool)
1636        self.assertIs(r, True)
1637        pg.set_bool(0)
1638        try:
1639            r = pg.get_bool()
1640        finally:
1641            pg.set_bool(use_bool)
1642        self.assertIsInstance(r, bool)
1643        self.assertIs(r, False)
1644        pg.set_bool(1)
1645        try:
1646            r = pg.get_bool()
1647        finally:
1648            pg.set_bool(use_bool)
1649        self.assertIsInstance(r, bool)
1650        self.assertIs(r, True)
1651
1652    def testSetBool(self):
1653        use_bool = pg.get_bool()
1654        # error if no parameter is passed
1655        self.assertRaises(TypeError, pg.set_bool)
1656        query = self.c.query
1657        try:
1658            r = query("select true::bool")
1659        except pg.ProgrammingError:
1660            self.skipTest('database does not support bool')
1661        r = r.getresult()[0][0]
1662        self.assertIsInstance(r, bool)
1663        self.assertEqual(r, True)
1664        r = query("select true::bool")
1665        pg.set_bool(False)
1666        try:
1667            r = r.getresult()[0][0]
1668        finally:
1669            pg.set_bool(use_bool)
1670        self.assertIsInstance(r, str)
1671        self.assertIs(r, 't')
1672        r = query("select true::bool")
1673        pg.set_bool(True)
1674        try:
1675            r = r.getresult()[0][0]
1676        finally:
1677            pg.set_bool(use_bool)
1678        self.assertIsInstance(r, bool)
1679        self.assertIs(r, True)
1680
1681    def testGetByteEscaped(self):
1682        bytea_escaped = pg.get_bytea_escaped()
1683        # error if a parameter is passed
1684        self.assertRaises(TypeError, pg.get_bytea_escaped, bytea_escaped)
1685        self.assertIsInstance(bytea_escaped, bool)
1686        self.assertIs(bytea_escaped, False)  # the default setting
1687        pg.set_bytea_escaped(True)
1688        try:
1689            r = pg.get_bytea_escaped()
1690        finally:
1691            pg.set_bytea_escaped(bytea_escaped)
1692        self.assertIsInstance(r, bool)
1693        self.assertIs(r, True)
1694        pg.set_bytea_escaped(False)
1695        try:
1696            r = pg.get_bytea_escaped()
1697        finally:
1698            pg.set_bytea_escaped(bytea_escaped)
1699        self.assertIsInstance(r, bool)
1700        self.assertIs(r, False)
1701        pg.set_bytea_escaped(1)
1702        try:
1703            r = pg.get_bytea_escaped()
1704        finally:
1705            pg.set_bytea_escaped(bytea_escaped)
1706        self.assertIsInstance(r, bool)
1707        self.assertIs(r, True)
1708        pg.set_bytea_escaped(0)
1709        try:
1710            r = pg.get_bytea_escaped()
1711        finally:
1712            pg.set_bytea_escaped(bytea_escaped)
1713        self.assertIsInstance(r, bool)
1714        self.assertIs(r, False)
1715
1716    def testSetByteaEscaped(self):
1717        bytea_escaped = pg.get_bytea_escaped()
1718        # error if no parameter is passed
1719        self.assertRaises(TypeError, pg.set_bytea_escaped)
1720        query = self.c.query
1721        try:
1722            r = query("select 'data'::bytea")
1723        except pg.ProgrammingError:
1724            self.skipTest('database does not support bytea')
1725        r = r.getresult()[0][0]
1726        self.assertIsInstance(r, bytes)
1727        self.assertEqual(r, b'data')
1728        r = query("select 'data'::bytea")
1729        pg.set_bytea_escaped(True)
1730        try:
1731            r = r.getresult()[0][0]
1732        finally:
1733            pg.set_bytea_escaped(bytea_escaped)
1734        self.assertIsInstance(r, str)
1735        self.assertEqual(r, '\\x64617461')
1736        r = query("select 'data'::bytea")
1737        pg.set_bytea_escaped(False)
1738        try:
1739            r = r.getresult()[0][0]
1740        finally:
1741            pg.set_bytea_escaped(bytea_escaped)
1742        self.assertIsInstance(r, bytes)
1743        self.assertEqual(r, b'data')
1744
1745    def testGetNamedresult(self):
1746        namedresult = pg.get_namedresult()
1747        # error if a parameter is passed
1748        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
1749        self.assertIs(namedresult, pg._namedresult)  # the default setting
1750
1751    def testSetNamedresult(self):
1752        namedresult = pg.get_namedresult()
1753        self.assertTrue(callable(namedresult))
1754
1755        query = self.c.query
1756
1757        r = query("select 1 as x, 2 as y").namedresult()[0]
1758        self.assertIsInstance(r, tuple)
1759        self.assertEqual(r, (1, 2))
1760        self.assertIsNot(type(r), tuple)
1761        self.assertEqual(r._fields, ('x', 'y'))
1762        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
1763        self.assertEqual(r.__class__.__name__, 'Row')
1764
1765        def listresult(q):
1766            return [list(row) for row in q.getresult()]
1767
1768        pg.set_namedresult(listresult)
1769        try:
1770            r = pg.get_namedresult()
1771            self.assertIs(r, listresult)
1772            r = query("select 1 as x, 2 as y").namedresult()[0]
1773            self.assertIsInstance(r, list)
1774            self.assertEqual(r, [1, 2])
1775            self.assertIsNot(type(r), tuple)
1776            self.assertFalse(hasattr(r, '_fields'))
1777            self.assertNotEqual(r.__class__.__name__, 'Row')
1778        finally:
1779            pg.set_namedresult(namedresult)
1780
1781        r = pg.get_namedresult()
1782        self.assertIs(r, namedresult)
1783
1784
class TestStandaloneEscapeFunctions(unittest.TestCase):
    """Test pg escape functions.

    The libpq interface memorizes some parameters of the last opened
    connection that influence the result of these functions.  Therefore
    we need to open a connection with fixed parameters prior to testing
    in order to ensure that the tests always run under the same conditions.
    That's why these tests are included in this test module.
    """

    # set to True by setUpClass once the connection parameters are fixed
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Fix the connection parameters the escape functions depend on."""
        db = connect()
        try:
            query = db.query
            query('set client_encoding=sql_ascii')
            query('set standard_conforming_strings=off')
            query('set bytea_output=escape')
        finally:
            # do not leak the connection; the parameters are memorized
            # from the last opened connection (see class docstring)
            db.close()
        cls.cls_set_up = True

    def testEscapeString(self):
        """Check escape_string() doubles quotes and backslashes.

        The result type must match the input type (bytes or unicode).
        """
        self.assertTrue(self.cls_set_up)
        f = pg.escape_string
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        r = f(u"das is' käse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, u"das is'' käse".encode('utf-8'))
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(r"It's bad to have a \ inside.")
        self.assertIsInstance(r, str)
        # with standard_conforming_strings=off backslashes are doubled too
        self.assertEqual(r, r"It''s bad to have a \\ inside.")

    def testEscapeBytea(self):
        """Check escape_bytea() escapes quotes and non-ASCII bytes.

        The result type must match the input type (bytes or unicode).
        """
        self.assertTrue(self.cls_set_up)
        f = pg.escape_bytea
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # 'ä' is encoded in UTF-8 as the octets \303\244
        r = f(u"das is' käse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(b'O\x00ps\xff!')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'O\\\\000ps\\\\377!')
1841
if __name__ == '__main__':
    # run all test cases of this module when executed as a script
    unittest.main(module='__main__')
Note: See TracBrowser for help on using the repository browser.