source: trunk/tests/test_classic_connection.py

Last change on this file was 1025, checked in by cito, 6 weeks ago

Make tests run with PostgreSQL 12

Skip/adapt tests that use tables with oids

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 87.9 KB
Line 
#!/usr/bin/python
# -*- coding: utf-8 -*-

"""Test the classic PyGreSQL interface.

Sub-tests for the low-level connection object.

Contributed by Christoph Zwerschke.

These tests need a database to test against.
"""

try:
    import unittest2 as unittest  # for Python < 2.7
except ImportError:
    import unittest
import threading
import time
import os

from collections import namedtuple
try:  # Python >= 3.3 (the alias in collections is gone since Python 3.10)
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
from decimal import Decimal

import pg  # the module under test

# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# These tests should be run with various PostgreSQL versions and databases
# created with different encodings and locales.  Particularly, make sure the
# tests are running against databases created with both SQL_ASCII and UTF8.
dbname = 'unittest'
dbhost = None
dbport = 5432

try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

try:  # noinspection PyUnresolvedReferences
    long
except NameError:  # Python >= 3.0
    long = int

try:  # noinspection PyUnresolvedReferences
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# True when str is a text type (Python 3), False when str is bytes (Python 2)
unicode_strings = str is not bytes

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'


def connect():
    """Return a new pg connection to the configured test database.

    The target is taken from the module-level ``dbname``, ``dbhost`` and
    ``dbport`` settings.  Messages below WARNING are silenced on the new
    session before it is handed back.
    """
    params = (dbname, dbhost, dbport)
    conn = pg.connect(*params)
    # keep low-level server messages out of the test output
    conn.query("set client_min_messages=warning")
    return conn


class TestCanConnect(unittest.TestCase):
    """Test whether a basic connection to PostgreSQL is possible."""

    def testCanConnect(self):
        """Open a connection to the test database and close it again."""
        try:
            conn = connect()
        except pg.Error as err:
            self.fail('Cannot connect to database %s:\n%s' % (dbname, err))
        else:
            try:
                conn.close()
            except pg.Error:
                self.fail('Cannot close the database connection')


class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods."""

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        try:
            self.connection.close()
        except pg.InternalError:
            pass  # closing an already closed connection raises InternalError

    def is_method(self, attribute):
        """Check if given attribute on the connection is a method."""
        if do_not_ask_for_host and attribute == 'host':
            return False
        return callable(getattr(self.connection, attribute))

    def testClassName(self):
        self.assertEqual(self.connection.__class__.__name__, 'Connection')

    def testModuleName(self):
        self.assertEqual(self.connection.__class__.__module__, 'pg')

    def testStr(self):
        r = str(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testRepr(self):
        r = repr(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testAllConnectAttributes(self):
        attributes = '''backend_pid db error host options port
            protocol_version server_version socket
            ssl_attributes ssl_in_use status user'''.split()
        connection_attributes = [a for a in dir(self.connection)
            if not a.startswith('__') and not self.is_method(a)]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        methods = '''
            cancel close date_format describe_prepared endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_cast_hook get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter
            prepare putline query query_prepared reset
            set_cast_hook set_notice_receiver source transaction
            '''.split()
        connection_methods = [a for a in dir(self.connection)
            if not a.startswith('__') and self.is_method(a)]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        if dbhost and not dbhost.startswith('/'):
            host = dbhost
        else:
            host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        # At least PostgreSQL 9.0 is expected.  The upper limit is only a
        # sanity cap: a tight bound such as 130000 breaks the test with
        # every new major release (numbers are major * 10000 since 10).
        self.assertGreaterEqual(server_version, 90000)
        self.assertLess(server_version, 1000000)

    def testAttributeSocket(self):
        socket = self.connection.socket
        self.assertIsInstance(socket, int)
        self.assertGreaterEqual(socket, 0)

    def testAttributeBackendPid(self):
        backend_pid = self.connection.backend_pid
        self.assertIsInstance(backend_pid, int)
        self.assertGreaterEqual(backend_pid, 1)

    def testAttributeSslInUse(self):
        ssl_in_use = self.connection.ssl_in_use
        self.assertIsInstance(ssl_in_use, bool)
        self.assertFalse(ssl_in_use)

    def testAttributeSslAttributes(self):
        ssl_attributes = self.connection.ssl_attributes
        self.assertIsInstance(ssl_attributes, dict)
        self.assertEqual(ssl_attributes, {
            'cipher': None, 'compression': None, 'key_bits': None,
            'library': None, 'protocol': None})

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testAllQueryMembers(self):
        query = self.connection.query("select true where false")
        members = '''
            dictiter dictresult fieldname fieldnum getresult listfields
            namediter namedresult ntuples one onedict onenamed onescalar
            scalariter scalarresult single singledict singlenamed singlescalar
            '''.split()
        query_members = [a for a in dir(query)
            if not a.startswith('__')
            and a != 'next']  # this is only needed in Python 2
        self.assertEqual(members, query_members)

    def testMethodEndcopy(self):
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # reconnect so that tearDown has a live connection to close
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query
        # check that client encoding gets reset
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        self.connection.query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.DatabaseError as error:
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        while 1:  # give the thread a moment to start the query
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)

    def testMethodTransaction(self):
        transaction = self.connection.transaction
        self.assertRaises(TypeError, transaction, None)
        self.assertEqual(transaction(), pg.TRANS_IDLE)
        self.connection.query('begin')
        self.assertEqual(transaction(), pg.TRANS_INTRANS)
        self.connection.query('rollback')
        self.assertEqual(transaction(), pg.TRANS_IDLE)

    def testMethodParameter(self):
        parameter = self.connection.parameter
        query = self.connection.query
        self.assertRaises(TypeError, parameter)
        r = parameter('this server setting does not exist')
        self.assertIsNone(r)
        # parameter() must agree with what the server reports via SHOW
        for name in ('server_version', 'server_encoding', 'client_encoding'):
            s = query('show %s' % name).getresult()[0][0]
            self.assertIsNotNone(s)
            r = parameter(name)
            self.assertEqual(r, s)


class TestSimpleQueries(unittest.TestCase):
    """Test simple queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # run the cleanups (they use the connection) before closing it
        self.doCleanups()
        self.c.close()

    def testClassName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__name__, 'Query')

    def testModuleName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__module__, 'pg')

    def testStr(self):
        # a UNION has no guaranteed row order, so order the rows explicitly
        q = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw' order by 1")
        r = self.c.query(q)
        self.assertEqual(str(r),
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)')

    def testRepr(self):
        r = repr(self.c.query("select 1"))
        self.assertTrue(r.startswith('<pg.Query object'), r)

    def testSelect0(self):
        q = "select 0"
        self.c.query(q)

    def testSelect0Semicolon(self):
        q = "select 0;"
        self.c.query(q)

    def testSelectDotSemicolon(self):
        q = "select .;"
        self.assertRaises(pg.DatabaseError, self.c.query, q)

    def testGetresult(self):
        q = "select 0"
        result = [(0,)]
        r = self.c.query(q).getresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, tuple)
        self.assertIsInstance(v[0], int)
        self.assertEqual(r, result)

    def testGetresultLong(self):
        q = "select 9876543210"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testGetresultDecimal(self):
        q = "select 98765432109876543210"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testGetresultString(self):
        result = 'Hello, world!'
        q = "select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresult(self):
        q = "select 0 as alias0"
        result = [{'alias0': 0}]
        r = self.c.query(q).dictresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, dict)
        self.assertIsInstance(v['alias0'], int)
        self.assertEqual(r, result)

    def testDictresultLong(self):
        q = "select 9876543210 as longjohnsilver"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testDictresultDecimal(self):
        q = "select 98765432109876543210 as longjohnsilver"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testDictresultString(self):
        result = 'Hello, world!'
        q = "select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testNamedresult(self):
        q = "select 0 as alias0"
        result = [(0,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('alias0',))
        self.assertEqual(v.alias0, 0)

    def testNamedresultWithGoodFieldnames(self):
        q = 'select 1 as snake_case_alias, 2 as "CamelCaseAlias"'
        result = [(1, 2)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('snake_case_alias', 'CamelCaseAlias'))

    def testNamedresultWithBadFieldnames(self):
        # compute the names namedtuple itself gives to invalid field names
        try:
            r = namedtuple('Bad', ['?'] * 6, rename=True)
        except TypeError:  # Python 2.6 or 3.0
            fields = tuple('column_%d' % n for n in range(6))
        else:
            fields = r._fields
        q = ('select 3 as "0alias", 4 as _alias, 5 as "alias$", 6 as "alias?",'
            ' 7 as "kebap-case-alias", 8 as break, 9 as and_a_good_one')
        result = [tuple(range(3, 10))]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields[:6], fields)
        self.assertEqual(v._fields[6], 'and_a_good_one')

    def testGet3Cols(self):
        q = "select 1,2,3"
        result = [(1, 2, 3)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [dict(a=1, b=2, c=3)]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [(1, 2, 3)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('a', 'b', 'c'))
        self.assertEqual(v.b, 2)

    def testGet3Rows(self):
        q = "select 3 union select 1 union select 2 order by 1"
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        for v in r:
            self.assertEqual(v._fields, ('alias3',))

    def testDictresultNames(self):
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [{'mixedcasealias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        result = [{'MixedCaseAlias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testNamedresultNames(self):
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [('MixedCase',)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('mixedcasealias',))
        self.assertEqual(v.mixedcasealias, 'MixedCase')
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('MixedCaseAlias',))
        self.assertEqual(v.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = 100
        num_rows = 100
        q = "select " + ','.join(map(str, range(num_cols)))
        q = ' union all '.join((q,) * num_rows)
        r = self.c.query(q).getresult()
        result = [tuple(range(num_cols))] * num_rows
        self.assertEqual(r, result)

    def testListfields(self):
        q = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        r = self.c.query(q).listfields()
        self.assertIsInstance(r, tuple)
        result = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(r, result)

    def testFieldname(self):
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldname(2)
        self.assertEqual(r, 'x')
        r = self.c.query(q).fieldname(3)
        self.assertEqual(r, 'y')

    def testFieldnum(self):
        q = "select 1 as x"
        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        r = self.c.query(q).fieldnum('y')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3)

    def testNtuples(self):  # deprecated
        q = "select 1 where false"
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 6)

    def testLen(self):
        q = "select 1 where false"
        self.assertEqual(len(self.c.query(q)), 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        self.assertEqual(len(self.c.query(q)), 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        self.assertEqual(len(self.c.query(q)), 6)

    def testQuery(self):
        query = self.c.query
        query("drop table if exists test_table")
        # "if exists" keeps a failed CREATE from causing a second error here
        self.addCleanup(query, "drop table if exists test_table")
        q = "create table test_table (n integer)"
        r = query(q)
        self.assertIsNone(r)
        q = "insert into test_table values (1)"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '1')
        q = "insert into test_table select 2"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '1')
        q = "select n from test_table where n>1"
        r = query(q).getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        q = "insert into test_table select 3 union select 4 union select 5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '3')
        q = "update test_table set n=4 where n<5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '4')
        q = "delete from test_table"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '5')

    def testQueryWithOids(self):
        if self.c.server_version >= 120000:
            self.skipTest("database does not support tables with oids")
        query = self.c.query
        query("drop table if exists test_table")
        # "if exists" keeps a failed CREATE from causing a second error here
        self.addCleanup(query, "drop table if exists test_table")
        q = "create table test_table (n integer) with oids"
        r = query(q)
        self.assertIsNone(r)
        # single-row inserts into a table with oids return the new oid
        q = "insert into test_table values (1)"
        r = query(q)
        self.assertIsInstance(r, int)
        q = "insert into test_table select 2"
        r = query(q)
        self.assertIsInstance(r, int)
        oid = r
        q = "select oid from test_table where n=2"
        r = query(q).getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertIsInstance(r, int)
        self.assertEqual(r, oid)
        q = "insert into test_table select 3 union select 4 union select 5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '3')
        q = "update test_table set n=4 where n<5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '4')
        q = "delete from test_table"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '5')


class TestUnicodeQueries(unittest.TestCase):
    """Test unicode strings as queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')

    def tearDown(self):
        self.c.close()

    def testGetresulAscii(self):
        result = u'Hello, world!'
        q = u"select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresulAscii(self):
        result = u'Hello, world!'
        q = u"select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('utf8')
        # pass the query as unicode
        try:
            v = self.c.query(q).getresult()[0][0]
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        # pass the query as bytes
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('utf8')
        try:
            v = self.c.query(q).dictresult()[0]['greeting']
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    # This was a second "testDictresultLatin1", which silently shadowed the
    # real one below, so this getresult() test never ran.
    def testGetresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        # uses characters that exist in Latin-9 but not in Latin-1
        # (œ, š, ž, €)
        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        # uses characters that exist in Latin-9 but not in Latin-1
        # (œ, š, ž, €)
        result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)'
        q = u"select '%s' as menu" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)


850class TestParamQueries(unittest.TestCase):
851    """Test queries with parameters via a basic pg connection."""
852
853    def setUp(self):
854        self.c = connect()
855        self.c.query('set client_encoding=utf8')
856
857    def tearDown(self):
858        self.c.close()
859
860    def testQueryWithNoneParam(self):
861        self.assertRaises(TypeError, self.c.query, "select $1", None)
862        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
863        self.assertEqual(self.c.query("select $1::integer", (None,)
864            ).getresult(), [(None,)])
865        self.assertEqual(self.c.query("select $1::text", [None]
866            ).getresult(), [(None,)])
867        self.assertEqual(self.c.query("select $1::text", [[None]]
868            ).getresult(), [(None,)])
869
870    def testQueryWithBoolParams(self, bool_enabled=None):
871        query = self.c.query
872        if bool_enabled is not None:
873            bool_enabled_default = pg.get_bool()
874            pg.set_bool(bool_enabled)
875        try:
876            bool_on = bool_enabled or bool_enabled is None
877            v_false, v_true = (False, True) if bool_on else 'ft'
878            r_false, r_true = [(v_false,)], [(v_true,)]
879            self.assertEqual(query("select false").getresult(), r_false)
880            self.assertEqual(query("select true").getresult(), r_true)
881            q = "select $1::bool"
882            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
883            self.assertEqual(query(q, ('f',)).getresult(), r_false)
884            self.assertEqual(query(q, ('t',)).getresult(), r_true)
885            self.assertEqual(query(q, ('false',)).getresult(), r_false)
886            self.assertEqual(query(q, ('true',)).getresult(), r_true)
887            self.assertEqual(query(q, ('n',)).getresult(), r_false)
888            self.assertEqual(query(q, ('y',)).getresult(), r_true)
889            self.assertEqual(query(q, (0,)).getresult(), r_false)
890            self.assertEqual(query(q, (1,)).getresult(), r_true)
891            self.assertEqual(query(q, (False,)).getresult(), r_false)
892            self.assertEqual(query(q, (True,)).getresult(), r_true)
893        finally:
894            if bool_enabled is not None:
895                pg.set_bool(bool_enabled_default)
896
897    def testQueryWithBoolParamsNotDefault(self):
898        self.testQueryWithBoolParams(bool_enabled=not pg.get_bool())
899
900    def testQueryWithIntParams(self):
901        query = self.c.query
902        self.assertEqual(query("select 1+1").getresult(), [(2,)])
903        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
904        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
905        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
906        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
907        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
908            [(Decimal('2'),)])
909        self.assertEqual(query("select 1, $1::integer", (2,)
910            ).getresult(), [(1, 2)])
911        self.assertEqual(query("select 1 union select $1::integer", (2,)
912            ).getresult(), [(1,), (2,)])
913        self.assertEqual(query("select $1::integer+$2", (1, 2)
914            ).getresult(), [(3,)])
915        self.assertEqual(query("select $1::integer+$2", [1, 2]
916            ).getresult(), [(3,)])
917        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
918            ).getresult(), [(15,)])
919
920    def testQueryWithStrParams(self):
921        query = self.c.query
922        self.assertEqual(query("select $1||', world!'", ('Hello',)
923            ).getresult(), [('Hello, world!',)])
924        self.assertEqual(query("select $1||', world!'", ['Hello']
925            ).getresult(), [('Hello, world!',)])
926        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
927            ).getresult(), [('Hello, world!',)])
928        self.assertEqual(query("select $1::text", ('Hello, world!',)
929            ).getresult(), [('Hello, world!',)])
930        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
931            ).getresult(), [('Hello', 'world')])
932        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
933            ).getresult(), [('Hello', 'world')])
934        self.assertEqual(query("select $1::text union select $2::text",
935            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
936        try:
937            query("select 'wörld'")
938        except (pg.DataError, pg.NotSupportedError):
939            self.skipTest('database does not support utf8')
940        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
941            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
942
943    def testQueryWithUnicodeParams(self):
944        query = self.c.query
945        try:
946            query('set client_encoding=utf8')
947            query("select 'wörld'").getresult()[0][0] == 'wörld'
948        except (pg.DataError, pg.NotSupportedError):
949            self.skipTest("database does not support utf8")
950        self.assertEqual(query("select $1||', '||$2||'!'",
951            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
952
953    def testQueryWithUnicodeParamsLatin1(self):
954        query = self.c.query
955        try:
956            query('set client_encoding=latin1')
957            query("select 'wörld'").getresult()[0][0] == 'wörld'
958        except (pg.DataError, pg.NotSupportedError):
959            self.skipTest("database does not support latin1")
960        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
961        if unicode_strings:
962            self.assertEqual(r, [('Hello, wörld!',)])
963        else:
964            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
965        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
966            ('Hello', u'ЌОр'))
967        query('set client_encoding=iso_8859_1')
968        r = query("select $1||', '||$2||'!'",
969            ('Hello', u'wörld')).getresult()
970        if unicode_strings:
971            self.assertEqual(r, [('Hello, wörld!',)])
972        else:
973            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
974        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
975            ('Hello', u'ЌОр'))
976        query('set client_encoding=sql_ascii')
977        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
978            ('Hello', u'wörld'))
979
980    def testQueryWithUnicodeParamsCyrillic(self):
981        query = self.c.query
982        try:
983            query('set client_encoding=iso_8859_5')
984            query("select 'ЌОр'").getresult()[0][0] == 'ЌОр'
985        except (pg.DataError, pg.NotSupportedError):
986            self.skipTest("database does not support cyrillic")
987        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
988            ('Hello', u'wörld'))
989        r = query("select $1||', '||$2||'!'",
990            ('Hello', u'ЌОр')).getresult()
991        if unicode_strings:
992            self.assertEqual(r, [('Hello, ЌОр!',)])
993        else:
994            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
995        query('set client_encoding=sql_ascii')
996        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
997            ('Hello', u'ЌОр!'))
998
999    def testQueryWithMixedParams(self):
1000        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
1001            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
1002        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
1003            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
1004
1005    def testQueryWithDuplicateParams(self):
1006        self.assertRaises(pg.ProgrammingError,
1007            self.c.query, "select $1+$1", (1,))
1008        self.assertRaises(pg.ProgrammingError,
1009            self.c.query, "select $1+$1", (1, 2))
1010
1011    def testQueryWithZeroParams(self):
1012        self.assertEqual(self.c.query("select 1+1", []
1013            ).getresult(), [(2,)])
1014
1015    def testQueryWithGarbage(self):
1016        garbage = r"'\{}+()-#[]oo324"
1017        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
1018            ).dictresult(), [{'garbage': garbage}])
1019
1020
1021class TestPreparedQueries(unittest.TestCase):
1022    """Test prepared queries via a basic pg connection."""
1023
1024    def setUp(self):
1025        self.c = connect()
1026        self.c.query('set client_encoding=utf8')
1027
1028    def tearDown(self):
1029        self.c.close()
1030
1031    def testEmptyPreparedStatement(self):
1032        self.c.prepare('', '')
1033        self.assertRaises(ValueError, self.c.query_prepared, '')
1034
1035    def testInvalidPreparedStatement(self):
1036        self.assertRaises(pg.ProgrammingError, self.c.prepare, '', 'bad')
1037
1038    def testDuplicatePreparedStatement(self):
1039        self.assertIsNone(self.c.prepare('q', 'select 1'))
1040        self.assertRaises(pg.ProgrammingError, self.c.prepare, 'q', 'select 2')
1041
1042    def testNonExistentPreparedStatement(self):
1043        self.assertRaises(pg.OperationalError,
1044            self.c.query_prepared, 'does-not-exist')
1045
1046    def testUnnamedQueryWithoutParams(self):
1047        self.assertIsNone(self.c.prepare('', "select 'anon'"))
1048        self.assertEqual(self.c.query_prepared('').getresult(), [('anon',)])
1049        self.assertEqual(self.c.query_prepared('').getresult(), [('anon',)])
1050
1051    def testNamedQueryWithoutParams(self):
1052        self.assertIsNone(self.c.prepare('hello', "select 'world'"))
1053        self.assertEqual(self.c.query_prepared('hello').getresult(),
1054            [('world',)])
1055
1056    def testMultipleNamedQueriesWithoutParams(self):
1057        self.assertIsNone(self.c.prepare('query17', "select 17"))
1058        self.assertIsNone(self.c.prepare('query42', "select 42"))
1059        self.assertEqual(self.c.query_prepared('query17').getresult(), [(17,)])
1060        self.assertEqual(self.c.query_prepared('query42').getresult(), [(42,)])
1061
1062    def testUnnamedQueryWithParams(self):
1063        self.assertIsNone(self.c.prepare('', "select $1 || ', ' || $2"))
1064        self.assertEqual(
1065            self.c.query_prepared('', ['hello', 'world']).getresult(),
1066            [('hello, world',)])
1067        self.assertIsNone(self.c.prepare('', "select 1+ $1 + $2 + $3"))
1068        self.assertEqual(
1069            self.c.query_prepared('', [17, -5, 29]).getresult(), [(42,)])
1070
1071    def testMultipleNamedQueriesWithParams(self):
1072        self.assertIsNone(self.c.prepare('q1', "select $1 || '!'"))
1073        self.assertIsNone(self.c.prepare('q2', "select $1 || '-' || $2"))
1074        self.assertEqual(self.c.query_prepared('q1', ['hello']).getresult(),
1075            [('hello!',)])
1076        self.assertEqual(self.c.query_prepared('q2', ['he', 'lo']).getresult(),
1077            [('he-lo',)])
1078
1079    def testDescribeNonExistentQuery(self):
1080        self.assertRaises(pg.OperationalError,
1081            self.c.describe_prepared, 'does-not-exist')
1082
1083    def testDescribeUnnamedQuery(self):
1084        self.c.prepare('', "select 1::int, 'a'::char")
1085        r = self.c.describe_prepared('')
1086        self.assertEqual(r.listfields(), ('int4', 'bpchar'))
1087
1088    def testDescribeNamedQuery(self):
1089        self.c.prepare('myquery', "select 1 as first, 2 as second")
1090        r = self.c.describe_prepared('myquery')
1091        self.assertEqual(r.listfields(), ('first', 'second'))
1092
1093    def testDescribeMultipleNamedQueries(self):
1094        self.c.prepare('query1', "select 1::int")
1095        self.c.prepare('query2', "select 1::int, 2::int")
1096        r = self.c.describe_prepared('query1')
1097        self.assertEqual(r.listfields(), ('int4',))
1098        r = self.c.describe_prepared('query2')
1099        self.assertEqual(r.listfields(), ('int4', 'int4'))
1100
1101
1102class TestQueryResultTypes(unittest.TestCase):
1103    """Test proper result types via a basic pg connection."""
1104
1105    def setUp(self):
1106        self.c = connect()
1107        self.c.query('set client_encoding=utf8')
1108        self.c.query("set datestyle='ISO,YMD'")
1109        self.c.query("set timezone='UTC'")
1110
1111    def tearDown(self):
1112        self.c.close()
1113
1114    def assert_proper_cast(self, value, pgtype, pytype):
1115        q = 'select $1::%s' % (pgtype,)
1116        try:
1117            r = self.c.query(q, (value,)).getresult()[0][0]
1118        except pg.ProgrammingError:
1119            if pgtype in ('json', 'jsonb'):
1120                self.skipTest('database does not support json')
1121        self.assertIsInstance(r, pytype)
1122        if isinstance(value, str):
1123            if not value or ' ' in value or '{' in value:
1124                value = '"%s"' % value
1125        value = '{%s}' % value
1126        r = self.c.query(q + '[]', (value,)).getresult()[0][0]
1127        if pgtype.startswith(('date', 'time', 'interval')):
1128            # arrays of these are casted by the DB wrapper only
1129            self.assertEqual(r, value)
1130        else:
1131            self.assertIsInstance(r, list)
1132            self.assertEqual(len(r), 1)
1133            self.assertIsInstance(r[0], pytype)
1134
1135    def testInt(self):
1136        self.assert_proper_cast(0, 'int', int)
1137        self.assert_proper_cast(0, 'smallint', int)
1138        self.assert_proper_cast(0, 'oid', int)
1139        self.assert_proper_cast(0, 'cid', int)
1140        self.assert_proper_cast(0, 'xid', int)
1141
1142    def testLong(self):
1143        self.assert_proper_cast(0, 'bigint', long)
1144
1145    def testFloat(self):
1146        self.assert_proper_cast(0, 'float', float)
1147        self.assert_proper_cast(0, 'real', float)
1148        self.assert_proper_cast(0, 'double', float)
1149        self.assert_proper_cast(0, 'double precision', float)
1150        self.assert_proper_cast('infinity', 'float', float)
1151
1152    def testFloat(self):
1153        decimal = pg.get_decimal()
1154        self.assert_proper_cast(decimal(0), 'numeric', decimal)
1155        self.assert_proper_cast(decimal(0), 'decimal', decimal)
1156
1157    def testMoney(self):
1158        decimal = pg.get_decimal()
1159        self.assert_proper_cast(decimal('0'), 'money', decimal)
1160
1161    def testBool(self):
1162        bool_type = bool if pg.get_bool() else str
1163        self.assert_proper_cast('f', 'bool', bool_type)
1164
1165    def testDate(self):
1166        self.assert_proper_cast('1956-01-31', 'date', str)
1167        self.assert_proper_cast('10:20:30', 'interval', str)
1168        self.assert_proper_cast('08:42:15', 'time', str)
1169        self.assert_proper_cast('08:42:15+00', 'timetz', str)
1170        self.assert_proper_cast('1956-01-31 08:42:15', 'timestamp', str)
1171        self.assert_proper_cast('1956-01-31 08:42:15+00', 'timestamptz', str)
1172
1173    def testText(self):
1174        self.assert_proper_cast('', 'text', str)
1175        self.assert_proper_cast('', 'char', str)
1176        self.assert_proper_cast('', 'bpchar', str)
1177        self.assert_proper_cast('', 'varchar', str)
1178
1179    def testBytea(self):
1180        self.assert_proper_cast('', 'bytea', bytes)
1181
1182    def testJson(self):
1183        self.assert_proper_cast('{}', 'json', dict)
1184
1185
1186class TestQueryIterator(unittest.TestCase):
1187    """Test the query operating as an iterator."""
1188
1189    def setUp(self):
1190        self.c = connect()
1191
1192    def tearDown(self):
1193        self.c.close()
1194
1195    def testLen(self):
1196        r = self.c.query("select generate_series(3,7)")
1197        self.assertEqual(len(r), 5)
1198
1199    def testGetItem(self):
1200        r = self.c.query("select generate_series(7,9)")
1201        self.assertEqual(r[0], (7,))
1202        self.assertEqual(r[1], (8,))
1203        self.assertEqual(r[2], (9,))
1204
1205    def testGetItemWithNegativeIndex(self):
1206        r = self.c.query("select generate_series(7,9)")
1207        self.assertEqual(r[-1], (9,))
1208        self.assertEqual(r[-2], (8,))
1209        self.assertEqual(r[-3], (7,))
1210
1211    def testGetItemOutOfRange(self):
1212        r = self.c.query("select generate_series(7,9)")
1213        self.assertRaises(IndexError, r.__getitem__, 3)
1214
1215    def testIterate(self):
1216        r = self.c.query("select generate_series(3,5)")
1217        self.assertNotIsInstance(r, (list, tuple))
1218        self.assertIsInstance(r, Iterable)
1219        self.assertEqual(list(r), [(3,), (4,), (5,)])
1220        self.assertIsInstance(r[1], tuple)
1221
1222    def testIterateTwice(self):
1223        r = self.c.query("select generate_series(3,5)")
1224        for i in range(2):
1225            self.assertEqual(list(r), [(3,), (4,), (5,)])
1226
1227    def testIterateTwoColumns(self):
1228        r = self.c.query("select 1,2 union select 3,4")
1229        self.assertIsInstance(r, Iterable)
1230        self.assertEqual(list(r), [(1, 2), (3, 4)])
1231
1232    def testNext(self):
1233        r = self.c.query("select generate_series(7,9)")
1234        self.assertEqual(next(r), (7,))
1235        self.assertEqual(next(r), (8,))
1236        self.assertEqual(next(r), (9,))
1237        self.assertRaises(StopIteration, next, r)
1238
1239    def testContains(self):
1240        r = self.c.query("select generate_series(7,9)")
1241        self.assertIn((8,), r)
1242        self.assertNotIn((5,), r)
1243
1244    def testDictIterate(self):
1245        r = self.c.query("select generate_series(3,5) as n").dictiter()
1246        self.assertNotIsInstance(r, (list, tuple))
1247        self.assertIsInstance(r, Iterable)
1248        r = list(r)
1249        self.assertEqual(r, [dict(n=3), dict(n=4), dict(n=5)])
1250        self.assertIsInstance(r[1], dict)
1251
1252    def testDictIterateTwoColumns(self):
1253        r = self.c.query("select 1 as one, 2 as two"
1254            " union select 3 as one, 4 as two").dictiter()
1255        self.assertIsInstance(r, Iterable)
1256        r = list(r)
1257        self.assertEqual(r, [dict(one=1, two=2), dict(one=3, two=4)])
1258
1259    def testDictNext(self):
1260        r = self.c.query("select generate_series(7,9) as n").dictiter()
1261        self.assertEqual(next(r), dict(n=7))
1262        self.assertEqual(next(r), dict(n=8))
1263        self.assertEqual(next(r), dict(n=9))
1264        self.assertRaises(StopIteration, next, r)
1265
1266    def testDictContains(self):
1267        r = self.c.query("select generate_series(7,9) as n").dictiter()
1268        self.assertIn(dict(n=8), r)
1269        self.assertNotIn(dict(n=5), r)
1270
1271    def testNamedIterate(self):
1272        r = self.c.query("select generate_series(3,5) as number").namediter()
1273        self.assertNotIsInstance(r, (list, tuple))
1274        self.assertIsInstance(r, Iterable)
1275        r = list(r)
1276        self.assertEqual(r, [(3,), (4,), (5,)])
1277        self.assertIsInstance(r[1], tuple)
1278        self.assertEqual(r[1]._fields, ('number',))
1279        self.assertEqual(r[1].number, 4)
1280
1281    def testNamedIterateTwoColumns(self):
1282        r = self.c.query("select 1 as one, 2 as two"
1283            " union select 3 as one, 4 as two").namediter()
1284        self.assertIsInstance(r, Iterable)
1285        r = list(r)
1286        self.assertEqual(r, [(1, 2), (3, 4)])
1287        self.assertEqual(r[0]._fields, ('one', 'two'))
1288        self.assertEqual(r[0].one, 1)
1289        self.assertEqual(r[1]._fields, ('one', 'two'))
1290        self.assertEqual(r[1].two, 4)
1291
1292    def testNamedNext(self):
1293        r = self.c.query("select generate_series(7,9) as number").namediter()
1294        self.assertEqual(next(r), (7,))
1295        self.assertEqual(next(r), (8,))
1296        n = next(r)
1297        self.assertEqual(n._fields, ('number',))
1298        self.assertEqual(n.number, 9)
1299        self.assertRaises(StopIteration, next, r)
1300
1301    def testNamedContains(self):
1302        r = self.c.query("select generate_series(7,9)").namediter()
1303        self.assertIn((8,), r)
1304        self.assertNotIn((5,), r)
1305
1306    def testScalarIterate(self):
1307        r = self.c.query("select generate_series(3,5)").scalariter()
1308        self.assertNotIsInstance(r, (list, tuple))
1309        self.assertIsInstance(r, Iterable)
1310        r = list(r)
1311        self.assertEqual(r, [3, 4, 5])
1312        self.assertIsInstance(r[1], int)
1313
1314    def testScalarIterateTwoColumns(self):
1315        r = self.c.query("select 1, 2 union select 3, 4").scalariter()
1316        self.assertIsInstance(r, Iterable)
1317        r = list(r)
1318        self.assertEqual(r, [1, 3])
1319
1320    def testScalarNext(self):
1321        r = self.c.query("select generate_series(7,9)").scalariter()
1322        self.assertEqual(next(r), 7)
1323        self.assertEqual(next(r), 8)
1324        self.assertEqual(next(r), 9)
1325        self.assertRaises(StopIteration, next, r)
1326
1327    def testScalarContains(self):
1328        r = self.c.query("select generate_series(7,9)").scalariter()
1329        self.assertIn(8, r)
1330        self.assertNotIn(5, r)
1331
1332
1333class TestQueryOneSingleScalar(unittest.TestCase):
1334    """Test the query methods for getting single rows and columns."""
1335
1336    def setUp(self):
1337        self.c = connect()
1338
1339    def tearDown(self):
1340        self.c.close()
1341
1342    def testOneWithEmptyQuery(self):
1343        q = self.c.query("select 0 where false")
1344        self.assertIsNone(q.one())
1345
1346    def testOneWithSingleRow(self):
1347        q = self.c.query("select 1, 2")
1348        r = q.one()
1349        self.assertIsInstance(r, tuple)
1350        self.assertEqual(r, (1, 2))
1351        self.assertEqual(q.one(), None)
1352
1353    def testOneWithTwoRows(self):
1354        q = self.c.query("select 1, 2 union select 3, 4")
1355        self.assertEqual(q.one(), (1, 2))
1356        self.assertEqual(q.one(), (3, 4))
1357        self.assertEqual(q.one(), None)
1358
1359    def testOneDictWithEmptyQuery(self):
1360        q = self.c.query("select 0 where false")
1361        self.assertIsNone(q.onedict())
1362
1363    def testOneDictWithSingleRow(self):
1364        q = self.c.query("select 1 as one, 2 as two")
1365        r = q.onedict()
1366        self.assertIsInstance(r, dict)
1367        self.assertEqual(r, dict(one=1, two=2))
1368        self.assertEqual(q.onedict(), None)
1369
1370    def testOneDictWithTwoRows(self):
1371        q = self.c.query(
1372            "select 1 as one, 2 as two union select 3 as one, 4 as two")
1373        self.assertEqual(q.onedict(), dict(one=1, two=2))
1374        self.assertEqual(q.onedict(), dict(one=3, two=4))
1375        self.assertEqual(q.onedict(), None)
1376
1377    def testOneNamedWithEmptyQuery(self):
1378        q = self.c.query("select 0 where false")
1379        self.assertIsNone(q.onenamed())
1380
1381    def testOneNamedWithSingleRow(self):
1382        q = self.c.query("select 1 as one, 2 as two")
1383        r = q.onenamed()
1384        self.assertEqual(r._fields, ('one', 'two'))
1385        self.assertEqual(r.one, 1)
1386        self.assertEqual(r.two, 2)
1387        self.assertEqual(r, (1, 2))
1388        self.assertEqual(q.onenamed(), None)
1389
1390    def testOneNamedWithTwoRows(self):
1391        q = self.c.query(
1392            "select 1 as one, 2 as two union select 3 as one, 4 as two")
1393        r = q.onenamed()
1394        self.assertEqual(r._fields, ('one', 'two'))
1395        self.assertEqual(r.one, 1)
1396        self.assertEqual(r.two, 2)
1397        self.assertEqual(r, (1, 2))
1398        r = q.onenamed()
1399        self.assertEqual(r._fields, ('one', 'two'))
1400        self.assertEqual(r.one, 3)
1401        self.assertEqual(r.two, 4)
1402        self.assertEqual(r, (3, 4))
1403        self.assertEqual(q.onenamed(), None)
1404
1405    def testOneScalarWithEmptyQuery(self):
1406        q = self.c.query("select 0 where false")
1407        self.assertIsNone(q.onescalar())
1408
1409    def testOneScalarWithSingleRow(self):
1410        q = self.c.query("select 1, 2")
1411        r = q.onescalar()
1412        self.assertIsInstance(r, int)
1413        self.assertEqual(r, 1)
1414        self.assertEqual(q.onescalar(), None)
1415
1416    def testOneScalarWithTwoRows(self):
1417        q = self.c.query("select 1, 2 union select 3, 4")
1418        self.assertEqual(q.onescalar(), 1)
1419        self.assertEqual(q.onescalar(), 3)
1420        self.assertEqual(q.onescalar(), None)
1421
1422    def testSingleWithEmptyQuery(self):
1423        q = self.c.query("select 0 where false")
1424        try:
1425            q.single()
1426        except pg.InvalidResultError as e:
1427            r = e
1428        else:
1429            r = None
1430        self.assertIsInstance(r, pg.NoResultError)
1431        self.assertEqual(str(r), 'No result found')
1432
1433    def testSingleWithSingleRow(self):
1434        q = self.c.query("select 1, 2")
1435        r = q.single()
1436        self.assertIsInstance(r, tuple)
1437        self.assertEqual(r, (1, 2))
1438        r = q.single()
1439        self.assertIsInstance(r, tuple)
1440        self.assertEqual(r, (1, 2))
1441
1442    def testSingleWithTwoRows(self):
1443        q = self.c.query("select 1, 2 union select 3, 4")
1444        try:
1445            q.single()
1446        except pg.InvalidResultError as e:
1447            r = e
1448        else:
1449            r = None
1450        self.assertIsInstance(r, pg.MultipleResultsError)
1451        self.assertEqual(str(r), 'Multiple results found')
1452
1453    def testSingleDictWithEmptyQuery(self):
1454        q = self.c.query("select 0 where false")
1455        try:
1456            q.singledict()
1457        except pg.InvalidResultError as e:
1458            r = e
1459        else:
1460            r = None
1461        self.assertIsInstance(r, pg.NoResultError)
1462        self.assertEqual(str(r), 'No result found')
1463
1464    def testSingleDictWithSingleRow(self):
1465        q = self.c.query("select 1 as one, 2 as two")
1466        r = q.singledict()
1467        self.assertIsInstance(r, dict)
1468        self.assertEqual(r, dict(one=1, two=2))
1469        r = q.singledict()
1470        self.assertIsInstance(r, dict)
1471        self.assertEqual(r, dict(one=1, two=2))
1472
1473    def testSingleDictWithTwoRows(self):
1474        q = self.c.query("select 1, 2 union select 3, 4")
1475        try:
1476            q.singledict()
1477        except pg.InvalidResultError as e:
1478            r = e
1479        else:
1480            r = None
1481        self.assertIsInstance(r, pg.MultipleResultsError)
1482        self.assertEqual(str(r), 'Multiple results found')
1483
1484    def testSingleNamedWithEmptyQuery(self):
1485        q = self.c.query("select 0 where false")
1486        try:
1487            q.singlenamed()
1488        except pg.InvalidResultError as e:
1489            r = e
1490        else:
1491            r = None
1492        self.assertIsInstance(r, pg.NoResultError)
1493        self.assertEqual(str(r), 'No result found')
1494
1495    def testSingleNamedWithSingleRow(self):
1496        q = self.c.query("select 1 as one, 2 as two")
1497        r = q.singlenamed()
1498        self.assertEqual(r._fields, ('one', 'two'))
1499        self.assertEqual(r.one, 1)
1500        self.assertEqual(r.two, 2)
1501        self.assertEqual(r, (1, 2))
1502        r = q.singlenamed()
1503        self.assertEqual(r._fields, ('one', 'two'))
1504        self.assertEqual(r.one, 1)
1505        self.assertEqual(r.two, 2)
1506        self.assertEqual(r, (1, 2))
1507
1508    def testSingleNamedWithTwoRows(self):
1509        q = self.c.query("select 1, 2 union select 3, 4")
1510        try:
1511            q.singlenamed()
1512        except pg.InvalidResultError as e:
1513            r = e
1514        else:
1515            r = None
1516        self.assertIsInstance(r, pg.MultipleResultsError)
1517        self.assertEqual(str(r), 'Multiple results found')
1518
1519    def testSingleScalarWithEmptyQuery(self):
1520        q = self.c.query("select 0 where false")
1521        try:
1522            q.singlescalar()
1523        except pg.InvalidResultError as e:
1524            r = e
1525        else:
1526            r = None
1527        self.assertIsInstance(r, pg.NoResultError)
1528        self.assertEqual(str(r), 'No result found')
1529
1530    def testSingleScalarWithSingleRow(self):
1531        q = self.c.query("select 1, 2")
1532        r = q.singlescalar()
1533        self.assertIsInstance(r, int)
1534        self.assertEqual(r, 1)
1535        r = q.singlescalar()
1536        self.assertIsInstance(r, int)
1537        self.assertEqual(r, 1)
1538
1539    def testSingleWithTwoRows(self):
1540        q = self.c.query("select 1, 2 union select 3, 4")
1541        try:
1542            q.singlescalar()
1543        except pg.InvalidResultError as e:
1544            r = e
1545        else:
1546            r = None
1547        self.assertIsInstance(r, pg.MultipleResultsError)
1548        self.assertEqual(str(r), 'Multiple results found')
1549
1550    def testScalarResult(self):
1551        q = self.c.query("select 1, 2 union select 3, 4")
1552        r = q.scalarresult()
1553        self.assertIsInstance(r, list)
1554        self.assertEqual(r, [1, 3])
1555
1556    def testScalarIter(self):
1557        q = self.c.query("select 1, 2 union select 3, 4")
1558        r = q.scalariter()
1559        self.assertNotIsInstance(r, (list, tuple))
1560        self.assertIsInstance(r, Iterable)
1561        r = list(r)
1562        self.assertEqual(r, [1, 3])
1563
1564
class TestInserttable(unittest.TestCase):
    """Test inserttable method.

    The tests insert rows into a table ``test`` that has one column for
    each of several common PostgreSQL data types, then read them back.
    """

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test ("
                "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
                "d numeric, f4 real, f8 double precision, m money,"
                "c char(1), v4 varchar(4), c4 char(4), t text)")
            # Check whether the test database uses SQL_ASCII - this means
            # that it does not consider encoding when calculating lengths.
            c.query("set client_encoding=utf8")
            try:
                c.query("select 'À'")
            except (pg.DataError, pg.NotSupportedError):
                cls.has_encoding = False
            else:
                cls.has_encoding = c.query(
                    "select length('À') - length('a')").getresult()[0][0] == 0
        finally:
            # do not leak the connection if one of the queries fails
            c.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        try:
            self.c.query("truncate table test")
        finally:
            self.c.close()

    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def db_len(cls, s, encoding):
        """Return the length of *s* as the database would count it.

        If the database considers the encoding, this is the number of
        characters, otherwise it is the number of encoded bytes.
        """
        if cls.has_encoding:
            s = s if isinstance(s, unicode) else s.decode(encoding)
        else:
            s = s.encode(encoding) if isinstance(s, unicode) else s
        return len(s)

    def _unicode_row(self, text):
        """Return a row of unicode test values with *text* in column t.

        Non-ascii chars do not fit in char(1) when there is no encoding,
        so the char(1) value is only non-ascii if the database has one.
        """
        c = u'€' if self.has_encoding else u'$'
        return (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0', c, u'bÀd', u'bÀd', text)

    @staticmethod
    def _encoded(row, encoding):
        """Return *row* with all unicode values encoded with *encoding*."""
        return tuple(s.encode(encoding) if isinstance(s, unicode) else s
            for s in row)

    def get_back(self, encoding='utf-8'):
        """Read back all rows of the test table, ordered by first column.

        The type of every non-null value is checked against its column.
        Numeric, money and char(4) values are converted back to the form
        in which they appear in the test data, so that the result can be
        compared directly with the inserted data.
        """
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], bool)
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision
                self.assertIsInstance(row[8], float)
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(self.db_len(row[10], encoding), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(self.db_len(row[11], encoding), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(self.db_len(row[12], encoding), 4)
                # char(n) values come back padded with blanks
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            data.append(tuple(row))
        return data

    def testInserttable1Row(self):
        data = self.data[2:3]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        data = self.data
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableFromTupleOfLists(self):
        data = tuple(list(row) for row in self.data)
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), self.data)

    def testInserttableFromSetofTuples(self):
        data = set(row for row in self.data)
        try:
            self.c.inserttable('test', data)
        except TypeError as e:
            r = str(e)
        else:
            r = 'this is fine'
        self.assertIn('list or a tuple as second argument', r)

    def testInserttableFromListOfSets(self):
        data = [set(row) for row in self.data]
        try:
            self.c.inserttable('test', data)
        except TypeError as e:
            r = str(e)
        else:
            r = 'this is fine'
        self.assertIn('second argument must contain a tuple or a list', r)

    def testInserttableMultipleRows(self):
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        data = [(None,) * 14] * 100
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableByteValues(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.DataError:
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [self._encoded(row_unicode, 'utf-8')] * 2
        self.c.inserttable('test', data)
        if unicode_strings:
            data = [row_unicode] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeUtf8(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.DataError:
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encoded(row_unicode, 'utf-8')] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeLatin1(self):
        try:
            self.c.query("set client_encoding=latin1")
            self.c.query("select '¥'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin1")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode € sign with latin1 encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
        # the yen sign is a single latin1 character, so it still fits
        # into the char(1) column when the database has an encoding
        row_unicode = tuple(s.replace(u'€', u'¥')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encoded(row_unicode, 'latin1')] * 2
        self.assertEqual(self.get_back('latin1'), data)

    def testInserttableUnicodeLatin9(self):
        try:
            self.c.query("set client_encoding=latin9")
            self.c.query("select '€'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest("database does not support latin9")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encoded(row_unicode, 'latin9')] * 2
        self.assertEqual(self.get_back('latin9'), data)

    def testInserttableNoEncoding(self):
        self.c.query("set client_encoding=sql_ascii")
        data = [self._unicode_row(u"for kÀse and pont-l'évêque pay in €")]
        # cannot encode non-ascii unicode without a specific encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
1826
1827
class TestDirectSocketAccess(unittest.TestCase):
    """Test the copy command using direct socket access.

    Rows are streamed into and out of a two-column table with the
    connection methods putline(), getline() and endcopy().
    """

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        conn = connect()
        conn.query("drop table if exists test cascade")
        conn.query("create table test (i int, v varchar(16))")
        conn.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        conn = connect()
        conn.query("drop table test cascade")
        conn.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set client_encoding=utf8")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    def testPutline(self):
        conn = self.c
        rows = list(enumerate("apple pear plum cherry banana".split()))
        conn.query("copy test from stdin")
        try:
            for number, fruit in rows:
                conn.putline("%d\t%s\n" % (number, fruit))
            # a line holding only backslash-dot terminates the copy data
            conn.putline("\\.\n")
        finally:
            conn.endcopy()
        self.assertEqual(conn.query("select * from test").getresult(), rows)

    def testPutlineBytesAndUnicode(self):
        conn = self.c
        try:
            conn.query("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        conn.query("copy test from stdin")
        try:
            # putline accepts both encoded bytes and native strings
            conn.putline(u"47\tkÀse\n".encode('utf8'))
            conn.putline("35\twÃŒrstel\n")
            conn.putline(b"\\.\n")
        finally:
            conn.endcopy()
        fetched = conn.query("select * from test").getresult()
        self.assertEqual(fetched, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        conn = self.c
        rows = list(enumerate("apple banana pear plum strawberry".split()))
        conn.inserttable('test', rows)
        conn.query("copy test to stdout")
        try:
            for row in rows:
                self.assertEqual(conn.getline(), '%d\t%s' % row)
            # the data is followed by the end marker, then nothing more
            self.assertEqual(conn.getline(), '\\.')
            self.assertIsNone(conn.getline())
        finally:
            try:
                conn.endcopy()
            except IOError:
                pass

    def testGetlineBytesAndUnicode(self):
        conn = self.c
        try:
            conn.query("select 'kÀse+wÃŒrstel'")
        except (pg.DataError, pg.NotSupportedError):
            self.skipTest('database does not support utf8')
        # one value is inserted as encoded bytes, the other one as text
        rows = [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')]
        conn.inserttable('test', rows)
        conn.query("copy test to stdout")
        try:
            for expected in ('54\tkÀse', '73\twÃŒrstel'):
                line = conn.getline()
                self.assertIsInstance(line, str)
                self.assertEqual(line, expected)
            self.assertEqual(conn.getline(), '\\.')
            self.assertIsNone(conn.getline())
        finally:
            try:
                conn.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        conn = self.c
        self.assertRaises(TypeError, conn.putline)
        self.assertRaises(TypeError, conn.getline, 'invalid')
        self.assertRaises(TypeError, conn.endcopy, 'invalid')
1938
1939
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # Run the registered cleanups explicitly while the connection is
        # still open, because they issue queries on that connection.
        try:
            self.doCleanups()
        finally:
            self.c.close()

    def _assertNotification(self, r, payload):
        """Check that *r* is a test_notify notification with *payload*.

        The notification must be a 3-tuple of the channel name (str),
        an int and the payload (str).
        """
        self.assertIsInstance(r, tuple)
        self.assertEqual(len(r), 3)
        self.assertIsInstance(r[0], str)
        self.assertIsInstance(r[1], int)
        self.assertIsInstance(r[2], str)
        self.assertEqual(r[0], 'test_notify')
        self.assertEqual(r[2], payload)

    def testGetNotify(self):
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            self._assertNotification(getnotify(), '')
            self.assertIsNone(getnotify())
            query("notify test_notify, 'test_payload'")
            self._assertNotification(getnotify(), 'test_payload')
            self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 'invalid')
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))
        self.assertIsNone(self.c.set_notice_receiver(None))

    def testSetAndGetNoticeReceiver(self):
        r = lambda notice: None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)
        self.assertIsNone(self.c.set_notice_receiver(None))
        self.assertIsNone(self.c.get_notice_receiver())

    def testNoticeReceiver(self):
        self.c.query('''create function bilbo_notice() returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        # Register the cleanup only once the function exists, so that a
        # failed creation is not followed by a second, failing drop.
        self.addCleanup(self.c.query, 'drop function bilbo_notice();')
        received = {}

        def notice_receiver(notice):
            # collect all public attributes of the notice object,
            # normalizing the German severity label to English
            for attr in dir(notice):
                if attr.startswith('__'):
                    continue
                value = getattr(notice, attr)
                if isinstance(value, str):
                    value = value.replace('WARNUNG', 'WARNING')
                received[attr] = value

        self.c.set_notice_receiver(notice_receiver)
        self.c.query('select bilbo_notice()')
        self.assertEqual(received, dict(
            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
            severity='WARNING', primary='Bilbo was here!',
            detail=None, hint=None))
2020
2021
class TestConfigFunctions(unittest.TestCase):
    """Test the functions for changing default settings.

    To test the effect of most of these functions, we need a database
    connection.  That's why they are covered in this test module.
    """

    def setUp(self):
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query('set bytea_output=hex')
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        self.c.close()

    def _set_lc_monetary(self, locales):
        """Set lc_monetary to the first of the given locales that works.

        Return True if one of the locales could be set, otherwise False.
        """
        for lc in locales:
            try:
                self.c.query("set lc_monetary='%s'" % lc)
            except pg.DataError:
                continue
            return True
        return False

    def _fetch_with_decimal_point(self, decimal_point, sql, restore):
        """Return the first value of *sql* fetched with a decimal point.

        The decimal point is reset to *restore* afterwards, even if the
        query fails, so that the global setting does not leak.
        """
        pg.set_decimal_point(decimal_point)
        try:
            return self.c.query(sql).getresult()[0][0]
        finally:
            pg.set_decimal_point(restore)

    def testGetDecimalPoint(self):
        point = pg.get_decimal_point()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal_point, point)
        self.assertIsInstance(point, str)
        self.assertEqual(point, '.')  # the default setting
        # punctuation characters are reported back unchanged,
        # while an empty string or None unsets the decimal point
        for setting, expected in (
                (',', ','), ("'", "'"), ('', None), (None, None)):
            pg.set_decimal_point(setting)
            try:
                r = pg.get_decimal_point()
            finally:
                pg.set_decimal_point(point)
            if expected is None:
                self.assertIsNone(r)
            else:
                self.assertIsInstance(r, str)
                self.assertEqual(r, expected)

    def testSetDecimalPoint(self):
        d = pg.Decimal
        point = pg.get_decimal_point()
        self.assertRaises(TypeError, pg.set_decimal_point)
        # error if decimal point is not a string
        self.assertRaises(TypeError, pg.set_decimal_point, 0)
        # error if more than one decimal point passed
        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
        # error if decimal point is not a punctuation character
        self.assertRaises(TypeError, pg.set_decimal_point, '0')
        query = self.c.query
        fetch = self._fetch_with_decimal_point
        # check that money values are interpreted as decimal values
        # only if decimal_point is set, and that the result is correct
        # only if it is set suitable for the current lc_monetary setting
        select_money = "select '34.25'::money"
        proper_money = d('34.25')
        bad_money = d('3425')
        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
        # first try with English localization (using the point)
        if not self._set_lc_monetary(en_locales):
            self.skipTest("cannot set English money locale")
        try:
            query(select_money)
        except (pg.DataError, pg.ProgrammingError):
            # this can happen if the currency signs cannot be
            # converted using the encoding of the test database
            self.skipTest("database does not support English money")
        # without a decimal point, money values are returned as strings
        for setting in (None, ''):
            r = fetch(setting, select_money, point)
            self.assertIsInstance(r, str)
            self.assertIn(r, en_money)
        r = fetch('.', select_money, point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, proper_money)
        # a decimal point not matching the locale gives a wrong value
        for setting in (',', "'"):
            r = fetch(setting, select_money, point)
            self.assertIsInstance(r, d)
            self.assertEqual(r, bad_money)
        # then try with German localization (using the comma)
        if not self._set_lc_monetary(de_locales):
            self.skipTest("cannot set German money locale")
        select_money = select_money.replace('.', ',')
        try:
            query(select_money)
        except (pg.DataError, pg.ProgrammingError):
            self.skipTest("database does not support German money")
        for setting in (None, ''):
            r = fetch(setting, select_money, point)
            self.assertIsInstance(r, str)
            self.assertIn(r, de_money)
        r = fetch(',', select_money, point)
        self.assertIsInstance(r, d)
        self.assertEqual(r, proper_money)
        for setting in ('.', "'"):
            r = fetch(setting, select_money, point)
            self.assertIsInstance(r, d)
            self.assertEqual(r, bad_money)

    def testGetDecimal(self):
        decimal_class = pg.get_decimal()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
        self.assertIs(decimal_class, pg.Decimal)  # the default setting
        pg.set_decimal(int)
        try:
            r = pg.get_decimal()
        finally:
            pg.set_decimal(decimal_class)
        self.assertIs(r, int)
        r = pg.get_decimal()
        self.assertIs(r, decimal_class)

    def testSetDecimal(self):
        decimal_class = pg.get_decimal()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_decimal)
        query = self.c.query
        try:
            r = query("select 3425::numeric")
        except pg.DatabaseError:
            self.skipTest('database does not support numeric')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, decimal_class)
        self.assertEqual(r, decimal_class('3425'))
        # the decimal class in effect when fetching the result is used
        r = query("select 3425::numeric")
        pg.set_decimal(int)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal(decimal_class)
        self.assertNotIsInstance(r, decimal_class)
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3425)

    def testGetBool(self):
        use_bool = pg.get_bool()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_bool, use_bool)
        self.assertIsInstance(use_bool, bool)
        self.assertIs(use_bool, True)  # the default setting
        # any setting is reported back as a proper bool
        for setting, expected in (
                (False, False), (True, True), (0, False), (1, True)):
            pg.set_bool(setting)
            try:
                r = pg.get_bool()
            finally:
                pg.set_bool(use_bool)
            self.assertIsInstance(r, bool)
            self.assertIs(r, expected)

    def testSetBool(self):
        use_bool = pg.get_bool()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_bool)
        query = self.c.query
        try:
            r = query("select true::bool")
        except pg.ProgrammingError:
            self.skipTest('database does not support bool')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, bool)
        self.assertEqual(r, True)
        pg.set_bool(False)
        try:
            r = query("select true::bool").getresult()[0][0]
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, str)
        # compare by value; string identity depends on interning
        self.assertEqual(r, 't')
        pg.set_bool(True)
        try:
            r = query("select true::bool").getresult()[0][0]
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)

    def testGetByteEscaped(self):
        bytea_escaped = pg.get_bytea_escaped()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_bytea_escaped, bytea_escaped)
        self.assertIsInstance(bytea_escaped, bool)
        self.assertIs(bytea_escaped, False)  # the default setting
        # any setting is reported back as a proper bool
        for setting, expected in (
                (True, True), (False, False), (1, True), (0, False)):
            pg.set_bytea_escaped(setting)
            try:
                r = pg.get_bytea_escaped()
            finally:
                pg.set_bytea_escaped(bytea_escaped)
            self.assertIsInstance(r, bool)
            self.assertIs(r, expected)

    def testSetByteaEscaped(self):
        bytea_escaped = pg.get_bytea_escaped()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_bytea_escaped)
        query = self.c.query
        try:
            r = query("select 'data'::bytea")
        except pg.ProgrammingError:
            self.skipTest('database does not support bytea')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'data')
        pg.set_bytea_escaped(True)
        try:
            r = query("select 'data'::bytea").getresult()[0][0]
        finally:
            pg.set_bytea_escaped(bytea_escaped)
        # the raw value in hex format (bytea_output is set in setUp)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '\\x64617461')
        pg.set_bytea_escaped(False)
        try:
            r = query("select 'data'::bytea").getresult()[0][0]
        finally:
            pg.set_bytea_escaped(bytea_escaped)
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'data')

    def testSetRowFactorySize(self):
        try:
            from functools import lru_cache
        except ImportError:  # Python < 3.2
            lru_cache = None
        queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc']
        query = self.c.query
        for maxsize in (None, 0, 1, 2, 3, 10, 1024):
            pg.set_row_factory_size(maxsize)
            for _i in range(3):
                for q in queries:
                    r = query(q).namedresult()[0]
                    if q.endswith('abc'):
                        self.assertEqual(r, (123,))
                        self.assertEqual(r._fields, ('abc',))
                    else:
                        self.assertEqual(r, (1, 2, 3))
                        self.assertEqual(r._fields, ('a', 'b', 'c'))
            if lru_cache:
                # 3 rounds of 2 queries: if the cache has room for both
                # row types, all but the first 2 lookups are cache hits
                info = pg._row_factory.cache_info()
                self.assertEqual(info.maxsize, maxsize)
                self.assertEqual(info.hits + info.misses, 6)
                self.assertEqual(info.hits,
                    0 if maxsize is not None and maxsize < 2 else 4)
2378
2379
class TestStandaloneEscapeFunctions(unittest.TestCase):
    """Test pg escape functions.

    The libpq interface memorizes some parameters of the last opened
    connection that influence the result of these functions.  Therefore
    we need to open a connection with fixed parameters prior to testing
    in order to ensure that the tests always run under the same conditions.
    That's why these tests are included in this test module.
    """

    # Set to True by setUpClass once the fixed connection parameters have
    # been applied; each test asserts it so a skipped setup is noticed.
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Open a connection with fixed escaping-related settings.

        Uses SQL_ASCII as client encoding, turns standard_conforming_strings
        off and requests the escape format for bytea output, then closes
        the connection again.  The connection is closed even if one of the
        settings fails; the error is then propagated.
        """
        db = connect()
        try:
            query = db.query
            query('set client_encoding=sql_ascii')
            query('set standard_conforming_strings=off')
            try:
                query('set bytea_output=escape')
            except pg.ProgrammingError:
                # bytea_output is unknown to servers older than 9.0, so the
                # error is only ignored for such older server versions.
                if db.server_version >= 90000:
                    raise
        finally:
            db.close()
        cls.cls_set_up = True

    def testEscapeString(self):
        """Check escape_string() with bytes and unicode arguments.

        The result has the same type as the argument.  Single quotes are
        doubled, and with the settings from setUpClass (conforming strings
        off) backslashes are doubled as well.  Non-ASCII bytes pass through.
        """
        self.assertTrue(self.cls_set_up)
        f = pg.escape_string
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        r = f(u"das is' käse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, u"das is'' käse".encode('utf-8'))
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(r"It's bad to have a \ inside.")
        self.assertEqual(r, r"It''s bad to have a \\ inside.")

    def testEscapeBytea(self):
        """Check escape_bytea() with bytes and unicode arguments.

        The result has the same type as the argument.  Single quotes are
        doubled, and non-ASCII or NUL bytes become octal escapes whose
        backslash is doubled (conforming strings are off in setUpClass).
        """
        self.assertTrue(self.cls_set_up)
        f = pg.escape_bytea
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # u'\xe4' ("ä") is encoded in UTF-8 as the bytes 0xC3 0xA4,
        # i.e. octal 303 and 244 in the expected escape sequence
        r = f(u"das is' käse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(b'O\x00ps\xff!')
        self.assertEqual(r, b'O\\\\000ps\\\\377!')
2441
2442
if __name__ == '__main__':
    # Discover and run all test cases of this module with the unittest
    # command-line runner (module='__main__' is also the default).
    unittest.main(module='__main__')
# Note: See TracBrowser for help on using the repository browser.