source: trunk/tests/test_classic_connection.py @ 998

Last change on this file since 998 was 998, checked in by cito, 6 months ago

Fix test for database with cyrillic charset

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 86.5 KB
Line 
1#!/usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17import threading
18import time
19import os
20
21from collections import namedtuple, Iterable
22from decimal import Decimal
23
24import pg  # the module under test
25
26# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
27# get our information from that.  Otherwise we use the defaults.
28# These tests should be run with various PostgreSQL versions and databases
29# created with different encodings and locales.  Particularly, make sure the
30# tests are running against databases created with both SQL_ASCII and UTF8.
# Connection parameters for the database the tests run against.  These
# defaults can be overridden by a LOCAL_PyGreSQL module if one exists.
# Run the tests against several PostgreSQL versions and against databases
# created with different encodings and locales -- in particular, make sure
# to run them against databases created with SQL_ASCII as well as UTF8.
dbname = 'unittest'
dbhost = None
dbport = 5432

# Pick up local overrides: first relative to the package, then as a
# top-level module.  ValueError is caught because Python 2 raises it for a
# relative import outside of a package.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Python 3 has no separate long type; use int under that name instead.
try:  # noinspection PyUnresolvedReferences
    long
except NameError:  # Python >= 3.0
    long = int

# Python 3 has no separate unicode type; use str under that name instead.
try:  # noinspection PyUnresolvedReferences
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# True when native str is a text type (Python 3), False on Python 2.
unicode_strings = bytes is not str

windows = (os.name == 'nt')

# There is a known bug in libpq under Windows which can crash the
# interface when PQhost() is called, so the host is not queried there:
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
61
62
def connect():
    """Open and return a new pg connection to the configured test database.

    The session's client_min_messages is raised to 'warning', so messages
    of lower severity are not sent to this client.
    """
    conn = pg.connect(dbname, dbhost, dbport)
    conn.query("set client_min_messages=warning")
    return conn
68
69
class TestCanConnect(unittest.TestCase):
    """Check that a connection to the test database can be opened and closed."""

    def testCanConnect(self):
        try:
            conn = connect()
        except pg.Error as exc:
            msg = 'Cannot connect to database %s:\n%s' % (dbname, exc)
            self.fail(msg)
        else:
            # only reached when connecting succeeded
            try:
                conn.close()
            except pg.Error:
                self.fail('Cannot close the database connection')
82
83
class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods.

    Each test gets a fresh connection from connect() in setUp, and tearDown
    closes it again.
    """

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        # Some tests close the connection themselves; closing an already
        # closed connection raises pg.InternalError, which is ignored here.
        try:
            self.connection.close()
        except pg.InternalError:
            pass

    def is_method(self, attribute):
        """Check if given attribute on the connection is a method.

        When do_not_ask_for_host is set (libpq may crash when asked for the
        host on Windows), 'host' is never accessed and counts as non-method.
        """
        if do_not_ask_for_host and attribute == 'host':
            return False
        return callable(getattr(self.connection, attribute))

    def testClassName(self):
        self.assertEqual(self.connection.__class__.__name__, 'Connection')

    def testModuleName(self):
        self.assertEqual(self.connection.__class__.__module__, 'pg')

    def testStr(self):
        r = str(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testRepr(self):
        r = repr(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testAllConnectAttributes(self):
        # dir() returns sorted names, so the expected list is alphabetical
        attributes = '''backend_pid db error host options port
            protocol_version server_version socket
            ssl_attributes ssl_in_use status user'''.split()
        connection_attributes = [a for a in dir(self.connection)
            if not a.startswith('__') and not self.is_method(a)]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        # dir() returns sorted names, so the expected list is alphabetical
        methods = '''
            cancel close date_format describe_prepared endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_cast_hook get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter
            prepare putline query query_prepared reset
            set_cast_hook set_notice_receiver source transaction
            '''.split()
        connection_methods = [a for a in dir(self.connection)
            if not a.startswith('__') and self.is_method(a)]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        # a host starting with '/' is a socket directory, reported as localhost
        if dbhost and not dbhost.startswith('/'):
            host = dbhost
        else:
            host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        # NOTE(review): the upper bound excludes PostgreSQL 12 and later;
        # raise it when running against newer servers.
        self.assertTrue(90000 <= server_version < 120000)

    def testAttributeSocket(self):
        socket = self.connection.socket
        self.assertIsInstance(socket, int)
        self.assertGreaterEqual(socket, 0)

    def testAttributeBackendPid(self):
        backend_pid = self.connection.backend_pid
        self.assertIsInstance(backend_pid, int)
        self.assertGreaterEqual(backend_pid, 1)

    def testAttributeSslInUse(self):
        ssl_in_use = self.connection.ssl_in_use
        self.assertIsInstance(ssl_in_use, bool)
        self.assertFalse(ssl_in_use)

    def testAttributeSslAttributes(self):
        ssl_attributes = self.connection.ssl_attributes
        self.assertIsInstance(ssl_attributes, dict)
        self.assertEqual(ssl_attributes, {
            'cipher': None, 'compression': None, 'key_bits': None,
            'library': None, 'protocol': None})

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testAllQueryMembers(self):
        query = self.connection.query("select true where false")
        members = '''
            dictiter dictresult fieldname fieldnum getresult listfields
            namediter namedresult ntuples one onedict onenamed onescalar
            scalariter scalarresult single singledict singlenamed singlescalar
            '''.split()
        query_members = [a for a in dir(query)
            if not a.startswith('__')
            and a != 'next']  # this is only needed in Python 2
        self.assertEqual(members, query_members)

    def testMethodEndcopy(self):
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # reconnect so that tearDown has an open connection to close
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query
        # check that client encoding gets reset
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        self.connection.query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        # A worker thread runs a 5 second pg_sleep on the shared connection;
        # cancel() from this thread should abort it well before it finishes.
        errors = []

        def sleep():
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.DatabaseError as error:
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        # Wait until the worker thread is alive (bounded at ~5 seconds).
        # Note this does not strictly prove the query reached the server.
        while 1:
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)

    def testMethodTransaction(self):
        transaction = self.connection.transaction
        self.assertRaises(TypeError, transaction, None)
        self.assertEqual(transaction(), pg.TRANS_IDLE)
        self.connection.query('begin')
        self.assertEqual(transaction(), pg.TRANS_INTRANS)
        self.connection.query('rollback')
        self.assertEqual(transaction(), pg.TRANS_IDLE)

    def testMethodParameter(self):
        parameter = self.connection.parameter
        query = self.connection.query
        self.assertRaises(TypeError, parameter)
        r = parameter('this server setting does not exist')
        self.assertIsNone(r)
        # Each parameter reported by the connection must agree with the
        # value the server returns for the corresponding SHOW command.
        # (Previously server_encoding was checked twice by copy and paste.)
        for name in ('server_version', 'server_encoding', 'client_encoding'):
            s = query('show %s' % name).getresult()[0][0]
            self.assertIsNotNone(s, name)
            r = parameter(name)
            self.assertEqual(r, s, name)
321        r = parameter('server_encoding')
322        self.assertEqual(r, s)
323        s = query('show client_encoding').getresult()[0][0]
324        self.assertIsNotNone(s)
325        r = parameter('client_encoding')
326        self.assertEqual(r, s)
327        s = query('show server_encoding').getresult()[0][0]
328        self.assertIsNotNone(s)
329        r = parameter('server_encoding')
330        self.assertEqual(r, s)
331
332
class TestSimpleQueries(unittest.TestCase):
    """Test simple queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # Run registered cleanups (e.g. dropping test tables) now, while the
        # connection they use is still open, and only then close it.
        self.doCleanups()
        self.c.close()

    def testClassName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__name__, 'Query')

    def testModuleName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__module__, 'pg')

    def testStr(self):
        q = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        r = self.c.query(q)
        self.assertEqual(str(r),
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)')

    def testRepr(self):
        r = repr(self.c.query("select 1"))
        self.assertTrue(r.startswith('<pg.Query object'), r)

    def testSelect0(self):
        q = "select 0"
        self.c.query(q)

    def testSelect0Semicolon(self):
        q = "select 0;"
        self.c.query(q)

    def testSelectDotSemicolon(self):
        q = "select .;"
        self.assertRaises(pg.DatabaseError, self.c.query, q)

    def testGetresult(self):
        q = "select 0"
        result = [(0,)]
        r = self.c.query(q).getresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, tuple)
        self.assertIsInstance(v[0], int)
        self.assertEqual(r, result)

    def testGetresultLong(self):
        q = "select 9876543210"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testGetresultDecimal(self):
        q = "select 98765432109876543210"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testGetresultString(self):
        result = 'Hello, world!'
        q = "select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresult(self):
        q = "select 0 as alias0"
        result = [{'alias0': 0}]
        r = self.c.query(q).dictresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, dict)
        self.assertIsInstance(v['alias0'], int)
        self.assertEqual(r, result)

    def testDictresultLong(self):
        q = "select 9876543210 as longjohnsilver"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testDictresultDecimal(self):
        q = "select 98765432109876543210 as longjohnsilver"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testDictresultString(self):
        result = 'Hello, world!'
        q = "select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testNamedresult(self):
        q = "select 0 as alias0"
        result = [(0,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('alias0',))
        self.assertEqual(v.alias0, 0)

    def testNamedresultWithGoodFieldnames(self):
        q = 'select 1 as snake_case_alias, 2 as "CamelCaseAlias"'
        result = [(1, 2)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('snake_case_alias', 'CamelCaseAlias'))

    def testNamedresultWithBadFieldnames(self):
        # Invalid identifiers are expected to be renamed the same way
        # namedtuple(..., rename=True) renames them on this Python version.
        try:
            r = namedtuple('Bad', ['?'] * 6, rename=True)
        except TypeError:  # Python 2.6 or 3.0
            fields = tuple('column_%d' % n for n in range(6))
        else:
            fields = r._fields
        q = ('select 3 as "0alias", 4 as _alias, 5 as "alias$", 6 as "alias?",'
            ' 7 as "kebap-case-alias", 8 as break, 9 as and_a_good_one')
        result = [tuple(range(3, 10))]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields[:6], fields)
        self.assertEqual(v._fields[6], 'and_a_good_one')

    def testGet3Cols(self):
        q = "select 1,2,3"
        result = [(1, 2, 3)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [dict(a=1, b=2, c=3)]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [(1, 2, 3)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('a', 'b', 'c'))
        self.assertEqual(v.b, 2)

    def testGet3Rows(self):
        q = "select 3 union select 1 union select 2 order by 1"
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        for v in r:
            self.assertEqual(v._fields, ('alias3',))

    def testDictresultNames(self):
        # unquoted aliases are folded to lower case, quoted ones are kept
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [{'mixedcasealias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        result = [{'MixedCaseAlias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testNamedresultNames(self):
        # unquoted aliases are folded to lower case, quoted ones are kept
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [('MixedCase',)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('mixedcasealias',))
        self.assertEqual(v.mixedcasealias, 'MixedCase')
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('MixedCaseAlias',))
        self.assertEqual(v.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = 100
        num_rows = 100
        q = "select " + ','.join(map(str, range(num_cols)))
        q = ' union all '.join((q,) * num_rows)
        r = self.c.query(q).getresult()
        result = [tuple(range(num_cols))] * num_rows
        self.assertEqual(r, result)

    def testListfields(self):
        q = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        r = self.c.query(q).listfields()
        self.assertIsInstance(r, tuple)
        result = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(r, result)

    def testFieldname(self):
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldname(2)
        self.assertEqual(r, 'x')
        r = self.c.query(q).fieldname(3)
        self.assertEqual(r, 'y')

    def testFieldnum(self):
        q = "select 1 as x"
        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        r = self.c.query(q).fieldnum('y')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3)

    def testNtuples(self):  # deprecated
        q = "select 1 where false"
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 6)

    def testLen(self):
        q = "select 1 where false"
        self.assertEqual(len(self.c.query(q)), 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        self.assertEqual(len(self.c.query(q)), 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        self.assertEqual(len(self.c.query(q)), 6)

    def testQuery(self):
        """Check the return value of query() for non-select statements.

        DDL gives None, a single-row insert gives the new row's oid as int,
        and other row-changing statements give the affected count as str.
        """
        query = self.c.query
        query("drop table if exists test_table")
        # "if exists" so that the cleanup does not raise a second, unrelated
        # error when creating the table below has already failed
        self.addCleanup(query, "drop table if exists test_table")
        q = "create table test_table (n integer) with oids"
        r = query(q)
        self.assertIsNone(r)
        q = "insert into test_table values (1)"
        r = query(q)
        self.assertIsInstance(r, int)
        q = "insert into test_table select 2"
        r = query(q)
        self.assertIsInstance(r, int)
        oid = r
        q = "select oid from test_table where n=2"
        r = query(q).getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertIsInstance(r, int)
        self.assertEqual(r, oid)
        q = "insert into test_table select 3 union select 4 union select 5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '3')
        q = "update test_table set n=4 where n<5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '4')
        q = "delete from test_table"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '5')
649
650
class TestUnicodeQueries(unittest.TestCase):
    """Test unicode strings as queries via a basic pg connection.

    Each test sends a query both as a unicode string and as encoded bytes.
    On Python 2 (unicode_strings false) str is a byte string, so the expected
    result is encoded with the current client encoding first.
    """

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')

    def tearDown(self):
        self.c.close()

    def testGetresulAscii(self):
        result = u'Hello, world!'
        q = u"select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresulAscii(self):
        result = u'Hello, world!'
        q = u"select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('utf8')
        # pass the query as unicode
        try:
            v = self.c.query(q).getresult()[0][0]
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        # pass the query as bytes
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('utf8')
        try:
            v = self.c.query(q).dictresult()[0]['greeting']
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin1(self):
        # This test was previously also named testDictresultLatin1, so it was
        # shadowed by the method below and never ran.
        try:
            self.c.query('set client_encoding=latin1')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        # the text includes characters specific to latin9 (œ, š, ž, €)
        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        # the text includes characters specific to latin9 (œ, š, ž, €)
        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
        q = u"select '%s' as menu" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
810
811
class TestParamQueries(unittest.TestCase):
    """Test queries with parameters via a basic pg connection."""

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')

    def tearDown(self):
        self.c.close()

    def testQueryWithNoneParam(self):
        # A bare None is rejected as parameter list, while parameter
        # sequences containing None yield NULL results.
        self.assertRaises(TypeError, self.c.query, "select $1", None)
        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
        self.assertEqual(self.c.query("select $1::integer", (None,)
            ).getresult(), [(None,)])
        self.assertEqual(self.c.query("select $1::text", [None]
            ).getresult(), [(None,)])
        self.assertEqual(self.c.query("select $1::text", [[None]]
            ).getresult(), [(None,)])

    def testQueryWithBoolParams(self, bool_enabled=None):
        # When bool_enabled is given, the module-wide bool setting is
        # switched temporarily and restored in the finally clause.
        query = self.c.query
        if bool_enabled is not None:
            bool_enabled_default = pg.get_bool()
            pg.set_bool(bool_enabled)
        try:
            bool_on = bool_enabled or bool_enabled is None
            # With bool support off, booleans come back as 'f'/'t' strings.
            v_false, v_true = (False, True) if bool_on else 'ft'
            r_false, r_true = [(v_false,)], [(v_true,)]
            self.assertEqual(query("select false").getresult(), r_false)
            self.assertEqual(query("select true").getresult(), r_true)
            q = "select $1::bool"
            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
            self.assertEqual(query(q, ('f',)).getresult(), r_false)
            self.assertEqual(query(q, ('t',)).getresult(), r_true)
            self.assertEqual(query(q, ('false',)).getresult(), r_false)
            self.assertEqual(query(q, ('true',)).getresult(), r_true)
            self.assertEqual(query(q, ('n',)).getresult(), r_false)
            self.assertEqual(query(q, ('y',)).getresult(), r_true)
            self.assertEqual(query(q, (0,)).getresult(), r_false)
            self.assertEqual(query(q, (1,)).getresult(), r_true)
            self.assertEqual(query(q, (False,)).getresult(), r_false)
            self.assertEqual(query(q, (True,)).getresult(), r_true)
        finally:
            if bool_enabled is not None:
                pg.set_bool(bool_enabled_default)

    def testQueryWithBoolParamsNotDefault(self):
        self.testQueryWithBoolParams(bool_enabled=not pg.get_bool())

    def testQueryWithIntParams(self):
        query = self.c.query
        self.assertEqual(query("select 1+1").getresult(), [(2,)])
        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
            [(Decimal('2'),)])
        self.assertEqual(query("select 1, $1::integer", (2,)
            ).getresult(), [(1, 2)])
        self.assertEqual(query("select 1 union select $1::integer", (2,)
            ).getresult(), [(1,), (2,)])
        self.assertEqual(query("select $1::integer+$2", (1, 2)
            ).getresult(), [(3,)])
        self.assertEqual(query("select $1::integer+$2", [1, 2]
            ).getresult(), [(3,)])
        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
            ).getresult(), [(15,)])

    def testQueryWithStrParams(self):
        query = self.c.query
        self.assertEqual(query("select $1||', world!'", ('Hello',)
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1||', world!'", ['Hello']
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1::text", ('Hello, world!',)
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
            ).getresult(), [('Hello', 'world')])
        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
            ).getresult(), [('Hello', 'world')])
        self.assertEqual(query("select $1::text union select $2::text",
            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
        try:
            query("select 'wörld'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])

    def testQueryWithUnicodeParams(self):
        query = self.c.query
        try:
            query('set client_encoding=utf8')
            # Probe only: we need to know whether the database accepts
            # this text; the returned value itself is not checked here.
            query("select 'wörld'").getresult()
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])

    def testQueryWithUnicodeParamsLatin1(self):
        query = self.c.query
        try:
            query('set client_encoding=latin1')
            # Probe only; comparing the result would be wrong without
            # native unicode strings, where it is returned as latin1 bytes.
            query("select 'wörld'").getresult()
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
        if unicode_strings:
            self.assertEqual(r, [('Hello, wörld!',)])
        else:
            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'ЌОр'))
        query('set client_encoding=iso_8859_1')
        r = query("select $1||', '||$2||'!'",
            ('Hello', u'wörld')).getresult()
        if unicode_strings:
            self.assertEqual(r, [('Hello, wörld!',)])
        else:
            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'ЌОр'))
        query('set client_encoding=sql_ascii')
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'wörld'))

    def testQueryWithUnicodeParamsCyrillic(self):
        query = self.c.query
        try:
            query('set client_encoding=iso_8859_5')
            # Probe only; comparing the result would be wrong without
            # native unicode strings, where it is returned as encoded bytes.
            query("select 'ЌОр'").getresult()
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support cyrillic")
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'wörld'))
        r = query("select $1||', '||$2||'!'",
            ('Hello', u'ЌОр')).getresult()
        if unicode_strings:
            self.assertEqual(r, [('Hello, ЌОр!',)])
        else:
            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
        query('set client_encoding=sql_ascii')
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'ЌОр!'))

    def testQueryWithMixedParams(self):
        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])

    def testQueryWithDuplicateParams(self):
        # Passing fewer or more parameters than placeholders is rejected.
        self.assertRaises(pg.ProgrammingError,
            self.c.query, "select $1+$1", (1,))
        self.assertRaises(pg.ProgrammingError,
            self.c.query, "select $1+$1", (1, 2))

    def testQueryWithZeroParams(self):
        self.assertEqual(self.c.query("select 1+1", []
            ).getresult(), [(2,)])

    def testQueryWithGarbage(self):
        # Special characters in a parameter must come back unchanged.
        garbage = r"'\{}+()-#[]oo324"
        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
            ).dictresult(), [{'garbage': garbage}])
971            self.c.query, "select $1+$1", (1, 2))
972
973    def testQueryWithZeroParams(self):
974        self.assertEqual(self.c.query("select 1+1", []
975            ).getresult(), [(2,)])
976
977    def testQueryWithGarbage(self):
978        garbage = r"'\{}+()-#[]oo324"
979        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
980            ).dictresult(), [{'garbage': garbage}])
981
982
983class TestPreparedQueries(unittest.TestCase):
984    """Test prepared queries via a basic pg connection."""
985
986    def setUp(self):
987        self.c = connect()
988        self.c.query('set client_encoding=utf8')
989
990    def tearDown(self):
991        self.c.close()
992
993    def testEmptyPreparedStatement(self):
994        self.c.prepare('', '')
995        self.assertRaises(ValueError, self.c.query_prepared, '')
996
997    def testInvalidPreparedStatement(self):
998        self.assertRaises(pg.ProgrammingError, self.c.prepare, '', 'bad')
999
1000    def testDuplicatePreparedStatement(self):
1001        self.assertIsNone(self.c.prepare('q', 'select 1'))
1002        self.assertRaises(pg.ProgrammingError, self.c.prepare, 'q', 'select 2')
1003
1004    def testNonExistentPreparedStatement(self):
1005        self.assertRaises(pg.OperationalError,
1006            self.c.query_prepared, 'does-not-exist')
1007
1008    def testUnnamedQueryWithoutParams(self):
1009        self.assertIsNone(self.c.prepare('', "select 'anon'"))
1010        self.assertEqual(self.c.query_prepared('').getresult(), [('anon',)])
1011        self.assertEqual(self.c.query_prepared('').getresult(), [('anon',)])
1012
1013    def testNamedQueryWithoutParams(self):
1014        self.assertIsNone(self.c.prepare('hello', "select 'world'"))
1015        self.assertEqual(self.c.query_prepared('hello').getresult(),
1016            [('world',)])
1017
1018    def testMultipleNamedQueriesWithoutParams(self):
1019        self.assertIsNone(self.c.prepare('query17', "select 17"))
1020        self.assertIsNone(self.c.prepare('query42', "select 42"))
1021        self.assertEqual(self.c.query_prepared('query17').getresult(), [(17,)])
1022        self.assertEqual(self.c.query_prepared('query42').getresult(), [(42,)])
1023
1024    def testUnnamedQueryWithParams(self):
1025        self.assertIsNone(self.c.prepare('', "select $1 || ', ' || $2"))
1026        self.assertEqual(
1027            self.c.query_prepared('', ['hello', 'world']).getresult(),
1028            [('hello, world',)])
1029        self.assertIsNone(self.c.prepare('', "select 1+ $1 + $2 + $3"))
1030        self.assertEqual(
1031            self.c.query_prepared('', [17, -5, 29]).getresult(), [(42,)])
1032
1033    def testMultipleNamedQueriesWithParams(self):
1034        self.assertIsNone(self.c.prepare('q1', "select $1 || '!'"))
1035        self.assertIsNone(self.c.prepare('q2', "select $1 || '-' || $2"))
1036        self.assertEqual(self.c.query_prepared('q1', ['hello']).getresult(),
1037            [('hello!',)])
1038        self.assertEqual(self.c.query_prepared('q2', ['he', 'lo']).getresult(),
1039            [('he-lo',)])
1040
1041    def testDescribeNonExistentQuery(self):
1042        self.assertRaises(pg.OperationalError,
1043            self.c.describe_prepared, 'does-not-exist')
1044
1045    def testDescribeUnnamedQuery(self):
1046        self.c.prepare('', "select 1::int, 'a'::char")
1047        r = self.c.describe_prepared('')
1048        self.assertEqual(r.listfields(), ('int4', 'bpchar'))
1049
1050    def testDescribeNamedQuery(self):
1051        self.c.prepare('myquery', "select 1 as first, 2 as second")
1052        r = self.c.describe_prepared('myquery')
1053        self.assertEqual(r.listfields(), ('first', 'second'))
1054
1055    def testDescribeMultipleNamedQueries(self):
1056        self.c.prepare('query1', "select 1::int")
1057        self.c.prepare('query2', "select 1::int, 2::int")
1058        r = self.c.describe_prepared('query1')
1059        self.assertEqual(r.listfields(), ('int4',))
1060        r = self.c.describe_prepared('query2')
1061        self.assertEqual(r.listfields(), ('int4', 'int4'))
1062
1063
class TestQueryResultTypes(unittest.TestCase):
    """Test proper result types via a basic pg connection."""

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set timezone='UTC'")

    def tearDown(self):
        self.c.close()

    def assert_proper_cast(self, value, pgtype, pytype):
        """Check that value cast to pgtype is returned as pytype.

        The value is first cast to the scalar type, then wrapped in array
        literal syntax and cast to the corresponding array type.  Only the
        json and jsonb types are skipped when the database rejects them;
        any other ProgrammingError is a genuine failure and is propagated.
        """
        q = 'select $1::%s' % (pgtype,)
        try:
            r = self.c.query(q, (value,)).getresult()[0][0]
        except pg.ProgrammingError:
            if pgtype in ('json', 'jsonb'):
                self.skipTest('database does not support json')
            # Re-raise instead of continuing with ``r`` unbound, which
            # would hide the real error behind an UnboundLocalError.
            raise
        self.assertIsInstance(r, pytype)
        if isinstance(value, str):
            # quote strings that would otherwise be misparsed as elements
            if not value or ' ' in value or '{' in value:
                value = '"%s"' % value
        value = '{%s}' % value
        r = self.c.query(q + '[]', (value,)).getresult()[0][0]
        if pgtype.startswith(('date', 'time', 'interval')):
            # arrays of these are casted by the DB wrapper only
            self.assertEqual(r, value)
        else:
            self.assertIsInstance(r, list)
            self.assertEqual(len(r), 1)
            self.assertIsInstance(r[0], pytype)

    def testInt(self):
        self.assert_proper_cast(0, 'int', int)
        self.assert_proper_cast(0, 'smallint', int)
        self.assert_proper_cast(0, 'oid', int)
        self.assert_proper_cast(0, 'cid', int)
        self.assert_proper_cast(0, 'xid', int)

    def testLong(self):
        self.assert_proper_cast(0, 'bigint', long)

    def testFloat(self):
        self.assert_proper_cast(0, 'float', float)
        self.assert_proper_cast(0, 'real', float)
        # PostgreSQL has no type named plain "double"; float8 is the alias
        # of double precision.
        self.assert_proper_cast(0, 'float8', float)
        self.assert_proper_cast(0, 'double precision', float)
        self.assert_proper_cast('infinity', 'float', float)

    def testNumeric(self):
        # Formerly also named testFloat, which shadowed the float test.
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal(0), 'numeric', decimal)
        self.assert_proper_cast(decimal(0), 'decimal', decimal)

    def testMoney(self):
        decimal = pg.get_decimal()
        self.assert_proper_cast(decimal('0'), 'money', decimal)

    def testBool(self):
        bool_type = bool if pg.get_bool() else str
        self.assert_proper_cast('f', 'bool', bool_type)

    def testDate(self):
        self.assert_proper_cast('1956-01-31', 'date', str)
        self.assert_proper_cast('10:20:30', 'interval', str)
        self.assert_proper_cast('08:42:15', 'time', str)
        self.assert_proper_cast('08:42:15+00', 'timetz', str)
        self.assert_proper_cast('1956-01-31 08:42:15', 'timestamp', str)
        self.assert_proper_cast('1956-01-31 08:42:15+00', 'timestamptz', str)

    def testText(self):
        self.assert_proper_cast('', 'text', str)
        self.assert_proper_cast('', 'char', str)
        self.assert_proper_cast('', 'bpchar', str)
        self.assert_proper_cast('', 'varchar', str)

    def testBytea(self):
        self.assert_proper_cast('', 'bytea', bytes)

    def testJson(self):
        self.assert_proper_cast('{}', 'json', dict)
1146
1147
class TestQueryIterator(unittest.TestCase):
    """Check that query results can be consumed like sequences/iterators."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testLen(self):
        q = self.c.query("select generate_series(3,7)")
        self.assertEqual(len(q), 5)

    def testGetItem(self):
        q = self.c.query("select generate_series(7,9)")
        for index, number in enumerate((7, 8, 9)):
            self.assertEqual(q[index], (number,))

    def testGetItemWithNegativeIndex(self):
        q = self.c.query("select generate_series(7,9)")
        for index, number in zip((-1, -2, -3), (9, 8, 7)):
            self.assertEqual(q[index], (number,))

    def testGetItemOutOfRange(self):
        q = self.c.query("select generate_series(7,9)")
        self.assertRaises(IndexError, q.__getitem__, 3)

    def testIterate(self):
        q = self.c.query("select generate_series(3,5)")
        self.assertNotIsInstance(q, (list, tuple))
        self.assertIsInstance(q, Iterable)
        rows = list(q)
        self.assertEqual(rows, [(3,), (4,), (5,)])
        self.assertIsInstance(q[1], tuple)

    def testIterateTwice(self):
        q = self.c.query("select generate_series(3,5)")
        expected = [(3,), (4,), (5,)]
        # iterating again starts over from the first row
        for _ in range(2):
            self.assertEqual(list(q), expected)

    def testIterateTwoColumns(self):
        q = self.c.query("select 1,2 union select 3,4")
        self.assertIsInstance(q, Iterable)
        rows = list(q)
        self.assertEqual(rows, [(1, 2), (3, 4)])

    def testNext(self):
        q = self.c.query("select generate_series(7,9)")
        for number in (7, 8, 9):
            self.assertEqual(next(q), (number,))
        self.assertRaises(StopIteration, next, q)

    def testContains(self):
        q = self.c.query("select generate_series(7,9)")
        self.assertIn((8,), q)
        self.assertNotIn((5,), q)

    def testDictIterate(self):
        it = self.c.query("select generate_series(3,5) as n").dictiter()
        self.assertNotIsInstance(it, (list, tuple))
        self.assertIsInstance(it, Iterable)
        rows = list(it)
        self.assertEqual(rows, [{'n': n} for n in (3, 4, 5)])
        self.assertIsInstance(rows[1], dict)

    def testDictIterateTwoColumns(self):
        sql = ("select 1 as one, 2 as two"
               " union select 3 as one, 4 as two")
        it = self.c.query(sql).dictiter()
        self.assertIsInstance(it, Iterable)
        rows = list(it)
        self.assertEqual(rows, [{'one': 1, 'two': 2}, {'one': 3, 'two': 4}])

    def testDictNext(self):
        it = self.c.query("select generate_series(7,9) as n").dictiter()
        for n in (7, 8, 9):
            self.assertEqual(next(it), {'n': n})
        self.assertRaises(StopIteration, next, it)

    def testDictContains(self):
        it = self.c.query("select generate_series(7,9) as n").dictiter()
        self.assertIn({'n': 8}, it)
        self.assertNotIn({'n': 5}, it)

    def testNamedIterate(self):
        sql = "select generate_series(3,5) as number"
        it = self.c.query(sql).namediter()
        self.assertNotIsInstance(it, (list, tuple))
        self.assertIsInstance(it, Iterable)
        rows = list(it)
        self.assertEqual(rows, [(3,), (4,), (5,)])
        row = rows[1]
        self.assertIsInstance(row, tuple)
        self.assertEqual(row._fields, ('number',))
        self.assertEqual(row.number, 4)

    def testNamedIterateTwoColumns(self):
        sql = ("select 1 as one, 2 as two"
               " union select 3 as one, 4 as two")
        it = self.c.query(sql).namediter()
        self.assertIsInstance(it, Iterable)
        rows = list(it)
        self.assertEqual(rows, [(1, 2), (3, 4)])
        first, second = rows
        self.assertEqual(first._fields, ('one', 'two'))
        self.assertEqual(first.one, 1)
        self.assertEqual(second._fields, ('one', 'two'))
        self.assertEqual(second.two, 4)

    def testNamedNext(self):
        sql = "select generate_series(7,9) as number"
        it = self.c.query(sql).namediter()
        for number in (7, 8):
            self.assertEqual(next(it), (number,))
        last = next(it)
        self.assertEqual(last._fields, ('number',))
        self.assertEqual(last.number, 9)
        self.assertRaises(StopIteration, next, it)

    def testNamedContains(self):
        it = self.c.query("select generate_series(7,9)").namediter()
        self.assertIn((8,), it)
        self.assertNotIn((5,), it)

    def testScalarIterate(self):
        it = self.c.query("select generate_series(3,5)").scalariter()
        self.assertNotIsInstance(it, (list, tuple))
        self.assertIsInstance(it, Iterable)
        values = list(it)
        self.assertEqual(values, [3, 4, 5])
        self.assertIsInstance(values[1], int)

    def testScalarIterateTwoColumns(self):
        it = self.c.query("select 1, 2 union select 3, 4").scalariter()
        self.assertIsInstance(it, Iterable)
        values = list(it)
        self.assertEqual(values, [1, 3])

    def testScalarNext(self):
        it = self.c.query("select generate_series(7,9)").scalariter()
        for number in (7, 8, 9):
            self.assertEqual(next(it), number)
        self.assertRaises(StopIteration, next, it)

    def testScalarContains(self):
        it = self.c.query("select generate_series(7,9)").scalariter()
        self.assertIn(8, it)
        self.assertNotIn(5, it)
1293
1294
class TestQueryOneSingleScalar(unittest.TestCase):
    """Test the query methods for getting single rows and columns."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    @staticmethod
    def _invalid_result_error(method):
        """Call method and return the InvalidResultError it raises.

        Returns None if the call succeeds, so that a following
        assertIsInstance() check on the result fails.
        """
        try:
            method()
        except pg.InvalidResultError as error:
            return error
        return None

    def testOneWithEmptyQuery(self):
        q = self.c.query("select 0 where false")
        self.assertIsNone(q.one())

    def testOneWithSingleRow(self):
        q = self.c.query("select 1, 2")
        r = q.one()
        self.assertIsInstance(r, tuple)
        self.assertEqual(r, (1, 2))
        self.assertEqual(q.one(), None)

    def testOneWithTwoRows(self):
        q = self.c.query("select 1, 2 union select 3, 4")
        self.assertEqual(q.one(), (1, 2))
        self.assertEqual(q.one(), (3, 4))
        self.assertEqual(q.one(), None)

    def testOneDictWithEmptyQuery(self):
        q = self.c.query("select 0 where false")
        self.assertIsNone(q.onedict())

    def testOneDictWithSingleRow(self):
        q = self.c.query("select 1 as one, 2 as two")
        r = q.onedict()
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(one=1, two=2))
        self.assertEqual(q.onedict(), None)

    def testOneDictWithTwoRows(self):
        q = self.c.query(
            "select 1 as one, 2 as two union select 3 as one, 4 as two")
        self.assertEqual(q.onedict(), dict(one=1, two=2))
        self.assertEqual(q.onedict(), dict(one=3, two=4))
        self.assertEqual(q.onedict(), None)

    def testOneNamedWithEmptyQuery(self):
        q = self.c.query("select 0 where false")
        self.assertIsNone(q.onenamed())

    def testOneNamedWithSingleRow(self):
        q = self.c.query("select 1 as one, 2 as two")
        r = q.onenamed()
        self.assertEqual(r._fields, ('one', 'two'))
        self.assertEqual(r.one, 1)
        self.assertEqual(r.two, 2)
        self.assertEqual(r, (1, 2))
        self.assertEqual(q.onenamed(), None)

    def testOneNamedWithTwoRows(self):
        q = self.c.query(
            "select 1 as one, 2 as two union select 3 as one, 4 as two")
        r = q.onenamed()
        self.assertEqual(r._fields, ('one', 'two'))
        self.assertEqual(r.one, 1)
        self.assertEqual(r.two, 2)
        self.assertEqual(r, (1, 2))
        r = q.onenamed()
        self.assertEqual(r._fields, ('one', 'two'))
        self.assertEqual(r.one, 3)
        self.assertEqual(r.two, 4)
        self.assertEqual(r, (3, 4))
        self.assertEqual(q.onenamed(), None)

    def testOneScalarWithEmptyQuery(self):
        q = self.c.query("select 0 where false")
        self.assertIsNone(q.onescalar())

    def testOneScalarWithSingleRow(self):
        q = self.c.query("select 1, 2")
        r = q.onescalar()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)
        self.assertEqual(q.onescalar(), None)

    def testOneScalarWithTwoRows(self):
        q = self.c.query("select 1, 2 union select 3, 4")
        self.assertEqual(q.onescalar(), 1)
        self.assertEqual(q.onescalar(), 3)
        self.assertEqual(q.onescalar(), None)

    def testSingleWithEmptyQuery(self):
        q = self.c.query("select 0 where false")
        r = self._invalid_result_error(q.single)
        self.assertIsInstance(r, pg.NoResultError)
        self.assertEqual(str(r), 'No result found')

    def testSingleWithSingleRow(self):
        q = self.c.query("select 1, 2")
        r = q.single()
        self.assertIsInstance(r, tuple)
        self.assertEqual(r, (1, 2))
        r = q.single()
        self.assertIsInstance(r, tuple)
        self.assertEqual(r, (1, 2))

    def testSingleWithTwoRows(self):
        q = self.c.query("select 1, 2 union select 3, 4")
        r = self._invalid_result_error(q.single)
        self.assertIsInstance(r, pg.MultipleResultsError)
        self.assertEqual(str(r), 'Multiple results found')

    def testSingleDictWithEmptyQuery(self):
        q = self.c.query("select 0 where false")
        r = self._invalid_result_error(q.singledict)
        self.assertIsInstance(r, pg.NoResultError)
        self.assertEqual(str(r), 'No result found')

    def testSingleDictWithSingleRow(self):
        q = self.c.query("select 1 as one, 2 as two")
        r = q.singledict()
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(one=1, two=2))
        r = q.singledict()
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(one=1, two=2))

    def testSingleDictWithTwoRows(self):
        q = self.c.query("select 1, 2 union select 3, 4")
        r = self._invalid_result_error(q.singledict)
        self.assertIsInstance(r, pg.MultipleResultsError)
        self.assertEqual(str(r), 'Multiple results found')

    def testSingleNamedWithEmptyQuery(self):
        q = self.c.query("select 0 where false")
        r = self._invalid_result_error(q.singlenamed)
        self.assertIsInstance(r, pg.NoResultError)
        self.assertEqual(str(r), 'No result found')

    def testSingleNamedWithSingleRow(self):
        q = self.c.query("select 1 as one, 2 as two")
        r = q.singlenamed()
        self.assertEqual(r._fields, ('one', 'two'))
        self.assertEqual(r.one, 1)
        self.assertEqual(r.two, 2)
        self.assertEqual(r, (1, 2))
        r = q.singlenamed()
        self.assertEqual(r._fields, ('one', 'two'))
        self.assertEqual(r.one, 1)
        self.assertEqual(r.two, 2)
        self.assertEqual(r, (1, 2))

    def testSingleNamedWithTwoRows(self):
        q = self.c.query("select 1, 2 union select 3, 4")
        r = self._invalid_result_error(q.singlenamed)
        self.assertIsInstance(r, pg.MultipleResultsError)
        self.assertEqual(str(r), 'Multiple results found')

    def testSingleScalarWithEmptyQuery(self):
        q = self.c.query("select 0 where false")
        r = self._invalid_result_error(q.singlescalar)
        self.assertIsInstance(r, pg.NoResultError)
        self.assertEqual(str(r), 'No result found')

    def testSingleScalarWithSingleRow(self):
        q = self.c.query("select 1, 2")
        r = q.singlescalar()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)
        r = q.singlescalar()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testSingleScalarWithTwoRows(self):
        # Formerly also named testSingleWithTwoRows, which shadowed the
        # single() test of the same name above.
        q = self.c.query("select 1, 2 union select 3, 4")
        r = self._invalid_result_error(q.singlescalar)
        self.assertIsInstance(r, pg.MultipleResultsError)
        self.assertEqual(str(r), 'Multiple results found')

    def testScalarResult(self):
        q = self.c.query("select 1, 2 union select 3, 4")
        r = q.scalarresult()
        self.assertIsInstance(r, list)
        self.assertEqual(r, [1, 3])

    def testScalarIter(self):
        q = self.c.query("select 1, 2 union select 3, 4")
        r = q.scalariter()
        self.assertNotIsInstance(r, (list, tuple))
        self.assertIsInstance(r, Iterable)
        r = list(r)
        self.assertEqual(r, [1, 3])
1525
1526
1527class TestInserttable(unittest.TestCase):
1528    """Test inserttable method."""
1529
1530    cls_set_up = False
1531
1532    @classmethod
1533    def setUpClass(cls):
1534        c = connect()
1535        c.query("drop table if exists test cascade")
1536        c.query("create table test ("
1537            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
1538            "d numeric, f4 real, f8 double precision, m money,"
1539            "c char(1), v4 varchar(4), c4 char(4), t text)")
1540        # Check whether the test database uses SQL_ASCII - this means
1541        # that it does not consider encoding when calculating lengths.
1542        c.query("set client_encoding=utf8")
1543        try:
1544            c.query("select 'À'")
1545        except (pg.DataError, pg.NotSupportedError):
1546            cls.has_encoding = False
1547        else:
1548            cls.has_encoding = c.query(
1549                "select length('À') - length('a')").getresult()[0][0] == 0
1550        c.close()
1551        cls.cls_set_up = True
1552
1553    @classmethod
1554    def tearDownClass(cls):
1555        c = connect()
1556        c.query("drop table test cascade")
1557        c.close()
1558
1559    def setUp(self):
1560        self.assertTrue(self.cls_set_up)
1561        self.c = connect()
1562        self.c.query("set client_encoding=utf8")
1563        self.c.query("set datestyle='ISO,YMD'")
1564        self.c.query("set lc_monetary='C'")
1565
1566    def tearDown(self):
1567        self.c.query("truncate table test")
1568        self.c.close()
1569
1570    data = [
1571        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
1572            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
1573        (0, 0, long(0), False, '1607-04-14', '09:00:00',
1574            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
1575        (1, 1, long(1), True, '1801-03-04', '03:45:00',
1576            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
1577        (2, 2, long(2), False, '1903-12-17', '11:22:00',
1578            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]
1579
1580    @classmethod
1581    def db_len(cls, s, encoding):
1582        if cls.has_encoding:
1583            s = s if isinstance(s, unicode) else s.decode(encoding)
1584        else:
1585            s = s.encode(encoding) if isinstance(s, unicode) else s
1586        return len(s)
1587
1588    def get_back(self, encoding='utf-8'):
1589        """Convert boolean and decimal values back."""
1590        data = []
1591        for row in self.c.query("select * from test order by 1").getresult():
1592            self.assertIsInstance(row, tuple)
1593            row = list(row)
1594            if row[0] is not None:  # smallint
1595                self.assertIsInstance(row[0], int)
1596            if row[1] is not None:  # integer
1597                self.assertIsInstance(row[1], int)
1598            if row[2] is not None:  # bigint
1599                self.assertIsInstance(row[2], long)
1600            if row[3] is not None:  # boolean
1601                self.assertIsInstance(row[3], bool)
1602            if row[4] is not None:  # date
1603                self.assertIsInstance(row[4], str)
1604                self.assertTrue(row[4].replace('-', '').isdigit())
1605            if row[5] is not None:  # time
1606                self.assertIsInstance(row[5], str)
1607                self.assertTrue(row[5].replace(':', '').isdigit())
1608            if row[6] is not None:  # numeric
1609                self.assertIsInstance(row[6], Decimal)
1610                row[6] = float(row[6])
1611            if row[7] is not None:  # real
1612                self.assertIsInstance(row[7], float)
1613            if row[8] is not None:  # double precision
1614                self.assertIsInstance(row[8], float)
1615                row[8] = float(row[8])
1616            if row[9] is not None:  # money
1617                self.assertIsInstance(row[9], Decimal)
1618                row[9] = str(float(row[9]))
1619            if row[10] is not None:  # char(1)
1620                self.assertIsInstance(row[10], str)
1621                self.assertEqual(self.db_len(row[10], encoding), 1)
1622            if row[11] is not None:  # varchar(4)
1623                self.assertIsInstance(row[11], str)
1624                self.assertLessEqual(self.db_len(row[11], encoding), 4)
1625            if row[12] is not None:  # char(4)
1626                self.assertIsInstance(row[12], str)
1627                self.assertEqual(self.db_len(row[12], encoding), 4)
1628                row[12] = row[12].rstrip()
1629            if row[13] is not None:  # text
1630                self.assertIsInstance(row[13], str)
1631            row = tuple(row)
1632            data.append(row)
1633        return data
1634
1635    def testInserttable1Row(self):
1636        data = self.data[2:3]
1637        self.c.inserttable('test', data)
1638        self.assertEqual(self.get_back(), data)
1639
1640    def testInserttable4Rows(self):
1641        data = self.data
1642        self.c.inserttable('test', data)
1643        self.assertEqual(self.get_back(), data)
1644
1645    def testInserttableFromTupleOfLists(self):
1646        data = tuple(list(row) for row in self.data)
1647        self.c.inserttable('test', data)
1648        self.assertEqual(self.get_back(), self.data)
1649
1650    def testInserttableFromSetofTuples(self):
1651        data = set(row for row in self.data)
1652        try:
1653            self.c.inserttable('test', data)
1654        except TypeError as e:
1655            r = str(e)
1656        else:
1657            r = 'this is fine'
1658        self.assertIn('list or a tuple as second argument', r)
1659
1660    def testInserttableFromListOfSets(self):
1661        data = [set(row) for row in self.data]
1662        try:
1663            self.c.inserttable('test', data)
1664        except TypeError as e:
1665            r = str(e)
1666        else:
1667            r = 'this is fine'
1668        self.assertIn('second argument must contain a tuple or a list', r)
1669
1670    def testInserttableMultipleRows(self):
1671        num_rows = 100
1672        data = self.data[2:3] * num_rows
1673        self.c.inserttable('test', data)
1674        r = self.c.query("select count(*) from test").getresult()[0][0]
1675        self.assertEqual(r, num_rows)
1676
1677    def testInserttableMultipleCalls(self):
1678        num_rows = 10
1679        data = self.data[2:3]
1680        for _i in range(num_rows):
1681            self.c.inserttable('test', data)
1682        r = self.c.query("select count(*) from test").getresult()[0][0]
1683        self.assertEqual(r, num_rows)
1684
1685    def testInserttableNullValues(self):
1686        data = [(None,) * 14] * 100
1687        self.c.inserttable('test', data)
1688        self.assertEqual(self.get_back(), data)
1689
1690    def testInserttableMaxValues(self):
1691        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
1692            True, '2999-12-31', '11:59:59', 1e99,
1693            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
1694            "1", "1234", "1234", "1234" * 100)]
1695        self.c.inserttable('test', data)
1696        self.assertEqual(self.get_back(), data)
1697
1698    def testInserttableByteValues(self):
1699        try:
1700            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
1701        except pg.DataError:
1702            self.skipTest("database does not support utf8")
1703        # non-ascii chars do not fit in char(1) when there is no encoding
1704        c = u'€' if self.has_encoding else u'$'
1705        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
1706            0.0, 0.0, 0.0, u'0.0',
1707            c, u'bÀd', u'bÀd', u"kÀse сыр pont-l'évêque")
1708        row_bytes = tuple(s.encode('utf-8')
1709            if isinstance(s, unicode) else s for s in row_unicode)
1710        data = [row_bytes] * 2
1711        self.c.inserttable('test', data)
1712        if unicode_strings:
1713            data = [row_unicode] * 2
1714        self.assertEqual(self.get_back(), data)
1715
1716    def testInserttableUnicodeUtf8(self):
1717        try:
1718            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
1719        except pg.DataError:
1720            self.skipTest("database does not support utf8")
1721        # non-ascii chars do not fit in char(1) when there is no encoding
1722        c = u'€' if self.has_encoding else u'$'
1723        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
1724            0.0, 0.0, 0.0, u'0.0',
1725            c, u'bÀd', u'bÀd', u"kÀse сыр pont-l'évêque")
1726        data = [row_unicode] * 2
1727        self.c.inserttable('test', data)
1728        if not unicode_strings:
1729            row_bytes = tuple(s.encode('utf-8')
1730                if isinstance(s, unicode) else s for s in row_unicode)
1731            data = [row_bytes] * 2
1732        self.assertEqual(self.get_back(), data)
1733
1734    def testInserttableUnicodeLatin1(self):
1735        try:
1736            self.c.query("set client_encoding=latin1")
1737            self.c.query("select 'Â¥'")
1738        except (pg.DataError, pg.NotSupportedError):
1739            self.skipTest("database does not support latin1")
1740        # non-ascii chars do not fit in char(1) when there is no encoding
1741        c = u'€' if self.has_encoding else u'$'
1742        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
1743            0.0, 0.0, 0.0, u'0.0',
1744            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
1745        data = [row_unicode]
1746        # cannot encode € sign with latin1 encoding
1747        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
1748        row_unicode = tuple(s.replace(u'€', u'Â¥')
1749            if isinstance(s, unicode) else s for s in row_unicode)
1750        data = [row_unicode] * 2
1751        self.c.inserttable('test', data)
1752        if not unicode_strings:
1753            row_bytes = tuple(s.encode('latin1')
1754                if isinstance(s, unicode) else s for s in row_unicode)
1755            data = [row_bytes] * 2
1756        self.assertEqual(self.get_back('latin1'), data)
1757
1758    def testInserttableUnicodeLatin9(self):
1759        try:
1760            self.c.query("set client_encoding=latin9")
1761            self.c.query("select '€'")
1762        except (pg.DataError, pg.NotSupportedError):
1763            self.skipTest("database does not support latin9")
1764            return
1765        # non-ascii chars do not fit in char(1) when there is no encoding
1766        c = u'€' if self.has_encoding else u'$'
1767        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
1768            0.0, 0.0, 0.0, u'0.0',
1769            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
1770        data = [row_unicode] * 2
1771        self.c.inserttable('test', data)
1772        if not unicode_strings:
1773            row_bytes = tuple(s.encode('latin9')
1774                if isinstance(s, unicode) else s for s in row_unicode)
1775            data = [row_bytes] * 2
1776        self.assertEqual(self.get_back('latin9'), data)
1777
1778    def testInserttableNoEncoding(self):
1779        self.c.query("set client_encoding=sql_ascii")
1780        # non-ascii chars do not fit in char(1) when there is no encoding
1781        c = u'€' if self.has_encoding else u'$'
1782        row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
1783            0.0, 0.0, 0.0, u'0.0',
1784            c, u'bÀd', u'bÀd', u"for kÀse and pont-l'évêque pay in €")
1785        data = [row_unicode]
1786        # cannot encode non-ascii unicode without a specific encoding
1787        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
1788
1789
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access.

    The tests feed data into and read data from COPY commands using the
    putline(), getline() and endcopy() methods of the connection.
    """

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Create the table used by the copy tests."""
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test (i int, v varchar(16))")
        finally:
            # close the connection even if the table cannot be created
            c.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the table used by the copy tests."""
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")

    def tearDown(self):
        try:
            self.c.query("truncate table test")
        finally:
            # close the connection even if the table cannot be truncated
            self.c.close()

    def testPutline(self):
        """Copy rows into the table line by line using putline()."""
        putline = self.c.putline
        query = self.c.query
        data = list(enumerate("apple pear plum cherry banana".split()))
        query("copy test from stdin")
        try:
            for i, v in data:
                putline("%d\t%s\n" % (i, v))
            putline("\\.\n")
        finally:
            self.c.endcopy()
        r = query("select * from test").getresult()
        self.assertEqual(r, data)

    def testPutlineBytesAndUnicode(self):
        """Check that putline() accepts both bytes and text lines."""
        putline = self.c.putline
        query = self.c.query
        try:
            query("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        query("copy test from stdin")
        try:
            putline(u"47\tkÀse\n".encode('utf8'))
            putline("35\twÃŒrstel\n")
            putline(b"\\.\n")
        finally:
            self.c.endcopy()
        r = query("select * from test").getresult()
        self.assertEqual(r, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        """Read a copy to stdout line by line using getline().

        After the data lines, getline() returns the terminator line
        and then None.
        """
        getline = self.c.getline
        query = self.c.query
        data = list(enumerate("apple banana pear plum strawberry".split()))
        n = len(data)
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            for i in range(n + 2):
                v = getline()
                if i < n:
                    self.assertEqual(v, '%d\t%s' % data[i])
                elif i == n:
                    self.assertEqual(v, '\\.')
                else:
                    self.assertIsNone(v)
        finally:
            # best effort: the copy may already have been finished
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testGetlineBytesAndUnicode(self):
        """Check that getline() returns str for bytes and text input."""
        getline = self.c.getline
        query = self.c.query
        try:
            query("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        data = [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')]
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            v = getline()
            self.assertIsInstance(v, str)
            self.assertEqual(v, '54\tkÀse')
            v = getline()
            self.assertIsInstance(v, str)
            self.assertEqual(v, '73\twÃŒrstel')
            self.assertEqual(getline(), '\\.')
            self.assertIsNone(getline())
        finally:
            # best effort: the copy may already have been finished
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        """Check that the copy methods reject wrong arguments."""
        self.assertRaises(TypeError, self.c.putline)
        self.assertRaises(TypeError, self.c.getline, 'invalid')
        self.assertRaises(TypeError, self.c.endcopy, 'invalid')
1900
1901
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # run the registered cleanups (which use the connection) while
        # the connection is still open, then close it
        self.doCleanups()
        self.c.close()

    def testGetNotify(self):
        """Check notifications with and without payload via getnotify()."""
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            r = getnotify()
            self.assertIsInstance(r, tuple)
            self.assertEqual(len(r), 3)
            self.assertIsInstance(r[0], str)
            self.assertIsInstance(r[1], int)
            self.assertIsInstance(r[2], str)
            self.assertEqual(r[0], 'test_notify')
            self.assertEqual(r[2], '')
            self.assertIsNone(getnotify())
            query("notify test_notify, 'test_payload'")
            r = getnotify()
            self.assertIsInstance(r, tuple)
            self.assertEqual(len(r), 3)
            self.assertIsInstance(r[0], str)
            self.assertIsInstance(r[1], int)
            self.assertIsInstance(r[2], str)
            self.assertEqual(r[0], 'test_notify')
            self.assertEqual(r[2], 'test_payload')
            self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        """Check that a new connection has no notice receiver."""
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        """Check that only callables or None are accepted as receivers."""
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 'invalid')
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))
        self.assertIsNone(self.c.set_notice_receiver(None))

    def testSetAndGetNoticeReceiver(self):
        """Check that the receiver that was set is returned again."""
        r = lambda notice: None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)
        self.assertIsNone(self.c.set_notice_receiver(None))
        self.assertIsNone(self.c.get_notice_receiver())

    def testNoticeReceiver(self):
        """Check that a server warning is passed to the notice receiver."""
        self.addCleanup(self.c.query, 'drop function bilbo_notice();')
        self.c.query('''create function bilbo_notice() returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        received = {}

        def notice_receiver(notice):
            for attr in dir(notice):
                if attr.startswith('__'):
                    continue
                value = getattr(notice, attr)
                if isinstance(value, str):
                    # normalize the severity reported by German servers
                    value = value.replace('WARNUNG', 'WARNING')
                received[attr] = value

        self.c.set_notice_receiver(notice_receiver)
        self.c.query('select bilbo_notice()')
        self.assertEqual(received, dict(
            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
            severity='WARNING', primary='Bilbo was here!',
            detail=None, hint=None))
1982
1983
class TestConfigFunctions(unittest.TestCase):
    """Test the functions for changing default settings.

    To test the effect of most of these functions, we need a database
    connection.  That's why they are covered in this test module.
    """

    def setUp(self):
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query('set bytea_output=hex')
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        self.c.close()

    def testGetDecimalPoint(self):
        """Check that get_decimal_point() reflects set_decimal_point()."""
        point = pg.get_decimal_point()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal_point, point)
        self.assertIsInstance(point, str)
        self.assertEqual(point, '.')  # the default setting
        pg.set_decimal_point(',')
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertEqual(r, ',')
        pg.set_decimal_point("'")
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertEqual(r, "'")
        # an empty string or None disables the decimal point
        pg.set_decimal_point('')
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsNone(r)
        pg.set_decimal_point(None)
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsNone(r)

    def testSetDecimalPoint(self):
        """Check how the decimal point setting affects money values."""
        d = pg.Decimal
        point = pg.get_decimal_point()
        self.assertRaises(TypeError, pg.set_decimal_point)
        # error if decimal point is not a string
        self.assertRaises(TypeError, pg.set_decimal_point, 0)
        # error if more than one decimal point passed
        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
        # error if decimal point is not a punctuation character
        self.assertRaises(TypeError, pg.set_decimal_point, '0')
        query = self.c.query
        # check that money values are interpreted as decimal values
        # only if decimal_point is set, and that the result is correct
        # only if it is set suitable for the current lc_monetary setting
        select_money = "select '34.25'::money"
        proper_money = d('34.25')
        bad_money = d('3425')
        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
        # first try with English localization (using the point)
        for lc in en_locales:
            try:
                query("set lc_monetary='%s'" % lc)
            except pg.DataError:
                pass
            else:
                break
        else:
            self.skipTest("cannot set English money locale")
        try:
            query(select_money)
        except (pg.DataError, pg.ProgrammingError):
            # this can happen if the currency signs cannot be
            # converted using the encoding of the test database
            self.skipTest("database does not support English money")
        pg.set_decimal_point(None)
        try:
            r = query(select_money).getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertIn(r, en_money)
        pg.set_decimal_point('')
        try:
            r = query(select_money).getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertIn(r, en_money)
        pg.set_decimal_point('.')
        try:
            r = query(select_money).getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, proper_money)
        pg.set_decimal_point(',')
        try:
            r = query(select_money).getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)
        pg.set_decimal_point("'")
        try:
            r = query(select_money).getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)
        # then try with German localization (using the comma)
        for lc in de_locales:
            try:
                query("set lc_monetary='%s'" % lc)
            except pg.DataError:
                pass
            else:
                break
        else:
            self.skipTest("cannot set German money locale")
        select_money = select_money.replace('.', ',')
        try:
            query(select_money)
        except (pg.DataError, pg.ProgrammingError):
            self.skipTest("database does not support German money")
        pg.set_decimal_point(None)
        try:
            r = query(select_money).getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertIn(r, de_money)
        pg.set_decimal_point('')
        try:
            r = query(select_money).getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertIn(r, de_money)
        pg.set_decimal_point(',')
        try:
            r = query(select_money).getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, proper_money)
        pg.set_decimal_point('.')
        try:
            r = query(select_money).getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        # check the type as in the English case above
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)
        pg.set_decimal_point("'")
        try:
            r = query(select_money).getresult()[0][0]
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)

    def testGetDecimal(self):
        """Check that get_decimal() reflects set_decimal()."""
        decimal_class = pg.get_decimal()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
        self.assertIs(decimal_class, pg.Decimal)  # the default setting
        pg.set_decimal(int)
        try:
            r = pg.get_decimal()
        finally:
            pg.set_decimal(decimal_class)
        self.assertIs(r, int)
        r = pg.get_decimal()
        self.assertIs(r, decimal_class)

    def testSetDecimal(self):
        """Check that numeric values are converted with the set class."""
        decimal_class = pg.get_decimal()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_decimal)
        query = self.c.query
        try:
            r = query("select 3425::numeric")
        except pg.DatabaseError:
            self.skipTest('database does not support numeric')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, decimal_class)
        self.assertEqual(r, decimal_class('3425'))
        r = query("select 3425::numeric")
        pg.set_decimal(int)
        try:
            # the conversion happens when the result is fetched
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal(decimal_class)
        self.assertNotIsInstance(r, decimal_class)
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3425)

    def testGetBool(self):
        """Check that get_bool() reflects set_bool()."""
        use_bool = pg.get_bool()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_bool, use_bool)
        self.assertIsInstance(use_bool, bool)
        self.assertIs(use_bool, True)  # the default setting
        pg.set_bool(False)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, False)
        pg.set_bool(True)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        pg.set_bool(0)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, False)
        pg.set_bool(1)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)

    def testSetBool(self):
        """Check that booleans are returned as bool or as 't'/'f' strings."""
        use_bool = pg.get_bool()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_bool)
        query = self.c.query
        try:
            r = query("select true::bool")
        except pg.ProgrammingError:
            self.skipTest('database does not support bool')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, bool)
        self.assertEqual(r, True)
        pg.set_bool(False)
        try:
            r = query("select true::bool").getresult()[0][0]
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, str)
        # compare strings by value, not by identity
        self.assertEqual(r, 't')
        pg.set_bool(True)
        try:
            r = query("select true::bool").getresult()[0][0]
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)

    def testGetByteEscaped(self):
        """Check that get_bytea_escaped() reflects set_bytea_escaped()."""
        bytea_escaped = pg.get_bytea_escaped()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_bytea_escaped, bytea_escaped)
        self.assertIsInstance(bytea_escaped, bool)
        self.assertIs(bytea_escaped, False)  # the default setting
        pg.set_bytea_escaped(True)
        try:
            r = pg.get_bytea_escaped()
        finally:
            pg.set_bytea_escaped(bytea_escaped)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        pg.set_bytea_escaped(False)
        try:
            r = pg.get_bytea_escaped()
        finally:
            pg.set_bytea_escaped(bytea_escaped)
        self.assertIsInstance(r, bool)
        self.assertIs(r, False)
        pg.set_bytea_escaped(1)
        try:
            r = pg.get_bytea_escaped()
        finally:
            pg.set_bytea_escaped(bytea_escaped)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        pg.set_bytea_escaped(0)
        try:
            r = pg.get_bytea_escaped()
        finally:
            pg.set_bytea_escaped(bytea_escaped)
        self.assertIsInstance(r, bool)
        self.assertIs(r, False)

    def testSetByteaEscaped(self):
        """Check that bytea is returned as bytes or as escaped string."""
        bytea_escaped = pg.get_bytea_escaped()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_bytea_escaped)
        query = self.c.query
        try:
            r = query("select 'data'::bytea")
        except pg.ProgrammingError:
            self.skipTest('database does not support bytea')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'data')
        pg.set_bytea_escaped(True)
        try:
            r = query("select 'data'::bytea").getresult()[0][0]
        finally:
            pg.set_bytea_escaped(bytea_escaped)
        self.assertIsInstance(r, str)
        # hex output format, as set up in setUp()
        self.assertEqual(r, '\\x64617461')
        pg.set_bytea_escaped(False)
        try:
            r = query("select 'data'::bytea").getresult()[0][0]
        finally:
            pg.set_bytea_escaped(bytea_escaped)
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'data')

    def testSetRowFactorySize(self):
        """Check named results and cache statistics for various sizes."""
        try:
            from functools import lru_cache
        except ImportError:  # Python < 3.2
            lru_cache = None
        queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc']
        query = self.c.query
        for maxsize in (None, 0, 1, 2, 3, 10, 1024):
            pg.set_row_factory_size(maxsize)
            for _ in range(3):
                for q in queries:
                    r = query(q).namedresult()[0]
                    if q.endswith('abc'):
                        self.assertEqual(r, (123,))
                        self.assertEqual(r._fields, ('abc',))
                    else:
                        self.assertEqual(r, (1, 2, 3))
                        self.assertEqual(r._fields, ('a', 'b', 'c'))
            if lru_cache:
                info = pg._row_factory.cache_info()
                self.assertEqual(info.maxsize, maxsize)
                # 3 rounds of 2 queries; a cache holding both field lists
                # misses only in the first round, smaller caches always miss
                self.assertEqual(info.hits + info.misses, 6)
                self.assertEqual(info.hits,
                    0 if maxsize is not None and maxsize < 2 else 4)
2340
2341
class TestStandaloneEscapeFunctions(unittest.TestCase):
    """Test pg escape functions.

    The libpq interface memorizes some parameters of the last opened
    connection that influence the result of these functions.  Therefore
    we need to open a connection with fixed parameters prior to testing
    in order to ensure that the tests always run under the same conditions.
    That's why these tests are included in this test module.
    """

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Open a connection with fixed escaping-related parameters."""
        db = connect()
        try:
            query = db.query
            query('set client_encoding=sql_ascii')
            query('set standard_conforming_strings=off')
            try:
                query('set bytea_output=escape')
            except pg.ProgrammingError:
                # the error is ignored for server versions before 9.0
                # (presumably these lack the bytea_output setting)
                if db.server_version >= 90000:
                    raise
        finally:
            # close the connection even if a setting is rejected
            db.close()
        cls.cls_set_up = True

    def testEscapeString(self):
        """Check escape_string() with bytes and text arguments."""
        self.assertTrue(self.cls_set_up)
        f = pg.escape_string
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        r = f(u"das is' kÀse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, u"das is'' kÀse".encode('utf-8'))
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        # backslashes are doubled since standard_conforming_strings is off
        r = f(r"It's bad to have a \ inside.")
        self.assertEqual(r, r"It''s bad to have a \\ inside.")

    def testEscapeBytea(self):
        """Check escape_bytea() with bytes and text arguments."""
        self.assertTrue(self.cls_set_up)
        f = pg.escape_bytea
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        r = f(u"das is' kÀse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(b'O\x00ps\xff!')
        self.assertEqual(r, b'O\\\\000ps\\\\377!')
2403
2404
if __name__ == '__main__':
    # run all test cases of this module when it is executed as a script
    unittest.main(module='__main__')
Note: See TracBrowser for help on using the repository browser.