source: trunk/tests/test_dbapi20_copy.py

Last change on this file was 997, checked in by cito, 6 months ago

Some more IDE hints

  • Property svn:keywords set to Id
File size: 19.8 KB
RevLine 
[969]1#!/usr/bin/python
[694]2# -*- coding: utf-8 -*-
3
4"""Test the modern PyGreSQL interface.
5
6Sub-tests for the copy methods.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
[965]18try:
19    from collections.abc import Iterable
20except ImportError:  # Python < 3.3
21    from collections import Iterable
[694]22
23import pgdb  # the module under test
24
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError, SystemError):
    # A relative import outside of a package fails differently depending
    # on the Python version: ValueError on Python 2, SystemError on
    # Python 3.3 to 3.5 and ImportError on later versions.  In all these
    # cases fall back to a plain module lying next to this file.
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local settings, keep the defaults above

try:  # noinspection PyUnresolvedReferences
    unicode
except NameError:  # Python >= 3.0
    unicode = str
44
45
class InputStream:
    """In-memory, file-like byte source used to feed copy_from.

    Text passed to the constructor is stored UTF-8 encoded.  Every call
    to read() appends the requested size to ``sizes`` so that the tests
    can check how the data has been consumed.
    """

    def __init__(self, data):
        if isinstance(data, unicode):
            data = data.encode('utf-8')
        self.data = data or b''  # None or empty input gives an empty buffer
        self.sizes = []  # size argument of every read() call, in order

    def __str__(self):
        data = self.data
        if str is unicode:  # Python >= 3.0: str() must return text
            data = data.decode('utf-8')
        return data

    def __len__(self):
        return len(self.data)

    def read(self, size=None):
        """Remove and return up to *size* bytes from the buffer.

        Like a regular binary file, a *size* of None or any negative
        number reads everything that is left.  The requested *size* is
        recorded unchanged in ``sizes``.
        """
        if size is None or size < 0:
            output, data = self.data, b''
        else:
            output, data = self.data[:size], self.data[size:]
        self.data = data
        self.sizes.append(size)
        return output
71
72
class OutputStream:
    """Write-only, in-memory stand-in for a binary file.

    Used as the target stream of copy_to: text chunks are UTF-8 encoded
    before being appended, and ``sizes`` logs the byte length of every
    chunk passed to write().
    """

    def __init__(self):
        self.data, self.sizes = b'', []

    def __str__(self):
        if str is not unicode:  # Python < 3.0: str already means bytes
            return self.data
        return self.data.decode('utf-8')

    def __len__(self):
        return len(self.data)

    def write(self, data):
        chunk = data.encode('utf-8') if isinstance(data, unicode) else data
        self.data += chunk
        self.sizes.append(len(chunk))
93
94
class TestStreams(unittest.TestCase):
    """Self-tests for the InputStream and OutputStream helpers."""

    def test_input(self):
        text = 'Hello, Wörld!'
        encoded = b'Hello, W\xc3\xb6rld!'
        stream = InputStream(text)
        # text is kept UTF-8 encoded, but str() gives back the text
        self.assertIsInstance(stream.data, bytes)
        self.assertEqual(stream.data, encoded)
        self.assertIsInstance(str(stream), str)
        self.assertEqual(str(stream), text)
        self.assertEqual(len(stream), 14)
        # consume the buffer piece by piece, then read whatever is left
        pieces = ((3, b'Hel'), (2, b'lo'), (1, b','), (1, b' '))
        for size, expected in pieces:
            self.assertEqual(stream.read(size), expected)
        self.assertEqual(stream.read(), b'W\xc3\xb6rld!')
        self.assertEqual(stream.read(), b'')
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [3, 2, 1, 1, None, None])

    def test_output(self):
        stream = OutputStream()
        self.assertEqual(len(stream), 0)
        chunks = ('Hel', 'lo', ',', ' ', 'Wörld!')
        for chunk in chunks:
            stream.write(chunk)
        # the chunks are concatenated as UTF-8 bytes
        self.assertIsInstance(stream.data, bytes)
        self.assertEqual(stream.data, b'Hello, W\xc3\xb6rld!')
        self.assertIsInstance(str(stream), str)
        self.assertEqual(str(stream), 'Hello, Wörld!')
        self.assertEqual(len(stream), 14)
        # the byte length of every chunk is logged ('ö' needs two bytes)
        self.assertEqual(stream.sizes, [3, 2, 1, 1, 7])
124
125
class TestCopy(unittest.TestCase):
    """Common fixture for the copy_from and copy_to tests.

    setUpClass (re)creates the copytest table; every test then gets its
    own connection and cursor with the client encoding set to UTF-8.
    """

    cls_set_up = False  # set to True at the end of a successful setUpClass

    @staticmethod
    def connect():
        """Open a new pgdb connection using the module-level settings."""
        return pgdb.connect(database=dbname,
            host='%s:%d' % (dbhost or '', dbport or -1))

    @classmethod
    def setUpClass(cls):
        """Recreate the copytest table and check UTF-8 support.

        If the database rejects the non-ASCII test names, the class gets
        ASCII-only replacement rows and can_encode is set to False.
        The connection is closed even if one of the statements fails.
        """
        con = cls.connect()
        try:
            cur = con.cursor()
            cur.execute("set client_min_messages=warning")
            cur.execute("drop table if exists copytest cascade")
            cur.execute("create table copytest ("
                "id smallint primary key, name varchar(64))")
            cur.close()
            con.commit()
            cur = con.cursor()
            try:
                cur.execute("set client_encoding=utf8")
                cur.execute("select 'Plácido and José'").fetchone()
            except (pgdb.DataError, pgdb.NotSupportedError):
                # Replace the rows on a copy instead of mutating the list
                # object that is shared by all subclasses of TestCopy.
                data = list(cls.data)
                data[1] = (1941, 'Plaacido Domingo')
                data[2] = (1946, 'Josee Carreras')
                cls.data = data
                cls.can_encode = False
            cur.close()
        finally:
            con.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the copytest table, always closing the connection."""
        con = cls.connect()
        try:
            cur = con.cursor()
            cur.execute("set client_min_messages=warning")
            cur.execute("drop table if exists copytest cascade")
            con.commit()
        finally:
            con.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.con = self.connect()
        self.cursor = self.con.cursor()
        self.cursor.execute("set client_encoding=utf8")

    def tearDown(self):
        # Best-effort cleanup: each step may fail independently (e.g. when
        # setUp did not get that far), and such errors are ignored so that
        # they do not mask the outcome of the test itself.
        try:
            self.cursor.close()
        except Exception:
            pass
        try:
            self.con.rollback()
        except Exception:
            pass
        try:
            self.con.close()
        except Exception:
            pass

    # rows used by the tests as (id, name) tuples
    data = [(1935, 'Luciano Pavarotti'),
            (1941, 'Plácido Domingo'),
            (1946, 'José Carreras')]

    can_encode = True  # set to False by setUpClass if UTF-8 is not usable

    @property
    def data_text(self):
        """The test rows in tab-separated text format."""
        return ''.join('%d\t%s\n' % row for row in self.data)

    @property
    def data_csv(self):
        """The test rows in comma-separated format."""
        return ''.join('%d,%s\n' % row for row in self.data)

    def truncate_table(self):
        """Remove all rows from copytest in the current transaction."""
        self.cursor.execute("truncate table copytest")

    @property
    def table_data(self):
        """All rows currently in copytest as returned by fetchall()."""
        self.cursor.execute("select * from copytest")
        return self.cursor.fetchall()

    def check_table(self):
        """Assert that copytest contains exactly the test rows."""
        self.assertEqual(self.table_data, self.data)

    def check_rowcount(self, number=len(data)):
        """Assert the rowcount of the cursor.

        The default is evaluated once when the class body runs, so it is
        the number of entries in the data list defined above (3).
        """
        self.assertEqual(self.cursor.rowcount, number)
[694]213
[697]214
class TestCopyFrom(TestCopy):
    """Test the copy_from method."""

    def tearDown(self):
        """Close the test connection and empty the copytest table.

        The truncation is done on a fresh connection and committed,
        because the base class tearDown rolls back the open transaction
        and would otherwise undo the truncation as well.
        """
        super(TestCopyFrom, self).tearDown()
        self.setUp()
        self.truncate_table()
        self.con.commit()
        super(TestCopyFrom, self).tearDown()

    def copy_from(self, stream, **options):
        """Run copy_from into the copytest table and return the result."""
        return self.cursor.copy_from(stream, 'copytest', **options)

    @property
    def data_file(self):
        """A new InputStream holding the test rows in text format."""
        return InputStream(self.data_text)

    def test_bad_params(self):
        call = self.cursor.copy_from
        # copy_from returns the cursor itself
        self.assertIs(call('0\t', 'copytest'), self.cursor)
        call('1\t', 'copytest',
             format='text', sep='\t', null='', columns=['id', 'name'])
        self.assertRaises(TypeError, call)
        self.assertRaises(TypeError, call, None)
        self.assertRaises(TypeError, call, None, None)
        self.assertRaises(TypeError, call, '0\t')
        self.assertRaises(TypeError, call, '0\t', None)
        self.assertRaises(TypeError, call, '0\t', 42)
        self.assertRaises(TypeError, call, '0\t', ['copytest'])
        self.assertRaises(TypeError, call, '0\t', 'copytest', format=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', format='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', sep=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', sep='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', null=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', size='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', columns=42)
        self.assertRaises(ValueError, call, b'', 'copytest',
            format='binary', sep=',')

    def test_input_string(self):
        ret = self.copy_from('42\tHello, world!')
        self.assertIs(ret, self.cursor)
        self.assertEqual(self.table_data, [(42, 'Hello, world!')])
        self.check_rowcount(1)

    def test_input_string_with_newline(self):
        self.copy_from('42\tHello, world!\n')
        self.assertEqual(self.table_data, [(42, 'Hello, world!')])
        self.check_rowcount(1)

    def test_input_string_multiple_rows(self):
        ret = self.copy_from(self.data_text)
        self.assertIs(ret, self.cursor)
        self.check_table()
        self.check_rowcount()

    if str is unicode:  # Python >= 3.0

        def test_input_bytes(self):
            self.copy_from(b'42\tHello, world!')
            self.assertEqual(self.table_data, [(42, 'Hello, world!')])
            self.truncate_table()
            self.copy_from(self.data_text.encode('utf-8'))
            self.check_table()

    else:  # Python < 3.0

        def test_input_unicode(self):
            if not self.can_encode:
                self.skipTest('database does not support utf8')
            self.copy_from(u'43\tWürstel, Käse!')
            self.assertEqual(self.table_data, [(43, 'Würstel, Käse!')])
            self.truncate_table()
            self.copy_from(self.data_text.decode('utf-8'))
            self.check_table()

    def test_input_iterable(self):
        self.copy_from(self.data_text.splitlines())
        self.check_table()
        self.check_rowcount()

    def test_input_iterable_invalid(self):
        self.assertRaises(IOError, self.copy_from, [None])

    def test_input_iterable_with_newlines(self):
        self.copy_from('%s\n' % row for row in self.data_text.splitlines())
        self.check_table()

    if str is unicode:  # Python >= 3.0

        def test_input_iterable_bytes(self):
            self.copy_from(row.encode('utf-8')
                for row in self.data_text.splitlines())
            self.check_table()

    def test_sep(self):
        stream = ('%d-%s' % row for row in self.data)
        self.copy_from(stream, sep='-')
        self.check_table()

    def test_null(self):
        self.copy_from('0\t\\N')
        self.assertEqual(self.table_data, [(0, None)])
        self.assertIsNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('1\tNix')
        self.assertEqual(self.table_data, [(1, 'Nix')])
        self.assertIsNotNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('2\tNix', null='Nix')
        self.assertEqual(self.table_data, [(2, None)])
        self.assertIsNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('3\t')
        self.assertEqual(self.table_data, [(3, '')])
        self.assertIsNotNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('4\t', null='')
        self.assertEqual(self.table_data, [(4, None)])
        self.assertIsNone(self.table_data[0][1])

    def test_columns(self):
        self.copy_from('1', columns='id')
        self.copy_from('2', columns=['id'])
        self.copy_from('3\tThree')
        self.copy_from('4\tFour', columns='id, name')
        self.copy_from('5\tFive', columns=['id', 'name'])
        self.assertEqual(self.table_data, [
            (1, None), (2, None), (3, 'Three'), (4, 'Four'), (5, 'Five')])
        self.check_rowcount(5)
        self.assertRaises(pgdb.ProgrammingError, self.copy_from,
            '6\t42', columns=['id', 'age'])
        self.check_rowcount(-1)

    def test_csv(self):
        self.copy_from(self.data_csv, format='csv')
        self.check_table()

    def test_csv_with_sep(self):
        stream = ('%d;"%s"\n' % row for row in self.data)
        self.copy_from(stream, format='csv', sep=';')
        self.check_table()
        self.check_rowcount()

    def test_binary(self):
        self.assertRaises(IOError, self.copy_from,
            b'NOPGCOPY\n', format='binary')
        self.check_rowcount(-1)

    def test_binary_with_sep(self):
        self.assertRaises(ValueError, self.copy_from,
            '', format='binary', sep='\t')

    def test_binary_with_unicode(self):
        self.assertRaises(ValueError, self.copy_from, u'', format='binary')

    def test_query(self):
        self.assertRaises(ValueError, self.cursor.copy_from, '', "select null")

    def test_file(self):
        stream = self.data_file
        ret = self.copy_from(stream)
        self.assertIs(ret, self.cursor)
        self.check_table()
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [8192])
        self.check_rowcount()

    def test_size_positive(self):
        stream = self.data_file
        size = 7
        num_chunks = (len(stream) + size - 1) // size
        self.copy_from(stream, size=size)
        self.check_table()
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [size] * num_chunks)
        self.check_rowcount()

    def test_size_negative(self):
        stream = self.data_file
        self.copy_from(stream, size=-1)
        self.check_table()
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [None])
        self.check_rowcount()

    def test_size_invalid(self):
        self.assertRaises(TypeError,
            self.copy_from, self.data_file, size='invalid')
[694]403
[772]404
class TestCopyTo(TestCopy):
    """Test the copy_to method."""

    @classmethod
    def setUpClass(cls):
        """Create the table (see TestCopy) and fill it with the test rows."""
        super(TestCopyTo, cls).setUpClass()
        con = cls.connect()
        try:
            cur = con.cursor()
            cur.execute("set client_encoding=utf8")
            # one insert per test row; executemany is the DB-API way to run
            # the same statement with several parameter tuples
            cur.executemany("insert into copytest values (%d, %s)", cls.data)
            cur.close()
            con.commit()
        finally:
            con.close()

    def copy_to(self, stream=None, **options):
        """Run copy_to on the copytest table and return the result."""
        return self.cursor.copy_to(stream, 'copytest', **options)

    @property
    def data_file(self):
        """A new, empty OutputStream to copy into."""
        return OutputStream()

    def test_bad_params(self):
        call = self.cursor.copy_to
        call(None, 'copytest')
        call(None, 'copytest',
             format='text', sep='\t', null='', columns=['id', 'name'])
        self.assertRaises(TypeError, call)
        self.assertRaises(TypeError, call, None)
        self.assertRaises(TypeError, call, None, 42)
        self.assertRaises(TypeError, call, None, ['copytest'])
        self.assertRaises(TypeError, call, 'bad', 'copytest')
        self.assertRaises(TypeError, call, None, 'copytest', format=42)
        self.assertRaises(ValueError, call, None, 'copytest', format='bad')
        self.assertRaises(TypeError, call, None, 'copytest', sep=42)
        self.assertRaises(ValueError, call, None, 'copytest', sep='bad')
        self.assertRaises(TypeError, call, None, 'copytest', null=42)
        self.assertRaises(TypeError, call, None, 'copytest', decode='bad')
        self.assertRaises(TypeError, call, None, 'copytest', columns=42)

    def test_generator(self):
        ret = self.copy_to()
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 3)
        rows = ''.join(rows)
        self.assertIsInstance(rows, str)
        self.assertEqual(rows, self.data_text)
        self.check_rowcount()

    if str is unicode:  # Python >= 3.0

        def test_generator_bytes(self):
            ret = self.copy_to(decode=False)
            self.assertIsInstance(ret, Iterable)
            rows = list(ret)
            self.assertEqual(len(rows), 3)
            rows = b''.join(rows)
            self.assertIsInstance(rows, bytes)
            self.assertEqual(rows, self.data_text.encode('utf-8'))

    else:  # Python < 3.0

        def test_generator_unicode(self):
            ret = self.copy_to(decode=True)
            self.assertIsInstance(ret, Iterable)
            rows = list(ret)
            self.assertEqual(len(rows), 3)
            rows = ''.join(rows)
            self.assertIsInstance(rows, unicode)
            self.assertEqual(rows, self.data_text.decode('utf-8'))

    def test_rowcount_increment(self):
        ret = self.copy_to()
        self.assertIsInstance(ret, Iterable)
        for n, row in enumerate(ret):
            self.check_rowcount(n + 1)

    def test_decode(self):
        ret_raw = b''.join(self.copy_to(decode=False))
        ret_decoded = ''.join(self.copy_to(decode=True))
        self.assertIsInstance(ret_raw, bytes)
        self.assertIsInstance(ret_decoded, unicode)
        self.assertEqual(ret_decoded, ret_raw.decode('utf-8'))
        self.check_rowcount()

    def test_sep(self):
        ret = list(self.copy_to(sep='-'))
        self.assertEqual(ret, ['%d-%s\n' % row for row in self.data])

    def test_null(self):
        data = ['%d\t%s\n' % row for row in self.data]
        self.cursor.execute('insert into copytest values(4, null)')
        try:
            ret = list(self.copy_to())
            self.assertEqual(ret, data + ['4\t\\N\n'])
            ret = list(self.copy_to(null='Nix'))
            self.assertEqual(ret, data + ['4\tNix\n'])
            ret = list(self.copy_to(null=''))
            self.assertEqual(ret, data + ['4\t\n'])
        finally:
            self.cursor.execute('delete from copytest where id=4')

    def test_columns(self):
        data_id = ''.join('%d\n' % row[0] for row in self.data)
        data_name = ''.join('%s\n' % row[1] for row in self.data)
        ret = ''.join(self.copy_to(columns='id'))
        self.assertEqual(ret, data_id)
        ret = ''.join(self.copy_to(columns=['id']))
        self.assertEqual(ret, data_id)
        ret = ''.join(self.copy_to(columns='name'))
        self.assertEqual(ret, data_name)
        ret = ''.join(self.copy_to(columns=['name']))
        self.assertEqual(ret, data_name)
        ret = ''.join(self.copy_to(columns='id, name'))
        self.assertEqual(ret, self.data_text)
        ret = ''.join(self.copy_to(columns=['id', 'name']))
        self.assertEqual(ret, self.data_text)
        self.assertRaises(pgdb.ProgrammingError, self.copy_to,
            columns=['id', 'age'])

    def test_csv(self):
        ret = self.copy_to(format='csv')
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 3)
        rows = ''.join(rows)
        self.assertIsInstance(rows, str)
        self.assertEqual(rows, self.data_csv)
        self.check_rowcount(3)

    def test_csv_with_sep(self):
        rows = ''.join(self.copy_to(format='csv', sep=';'))
        self.assertEqual(rows, self.data_csv.replace(',', ';'))

    def test_binary(self):
        ret = self.copy_to(format='binary')
        self.assertIsInstance(ret, Iterable)
        for row in ret:
            self.assertTrue(row.startswith(b'PGCOPY\n\377\r\n\0'))
            break
        self.check_rowcount(1)

    def test_binary_with_sep(self):
        self.assertRaises(ValueError, self.copy_to, format='binary', sep='\t')

    def test_binary_with_unicode(self):
        self.assertRaises(ValueError, self.copy_to,
            format='binary', decode=True)

    def test_query(self):
        self.assertRaises(ValueError, self.cursor.copy_to, None,
            "select name from copytest", columns='noname')
        ret = self.cursor.copy_to(None,
            "select name||'!' from copytest where id=1941")
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 1)
        self.assertIsInstance(rows[0], str)
        self.assertEqual(rows[0], '%s!\n' % self.data[1][1])
        self.check_rowcount(1)

    def test_file(self):
        stream = self.data_file
        ret = self.copy_to(stream)
        self.assertIs(ret, self.cursor)
        self.assertEqual(str(stream), self.data_text)
        data = self.data_text
        if str is unicode:  # Python >= 3.0
            data = data.encode('utf-8')
        sizes = [len(row) + 1 for row in data.splitlines()]
        self.assertEqual(stream.sizes, sizes)
        self.check_rowcount()
[694]577
578
class TestBinary(TestCopy):
    """Test the copy_from and copy_to methods with binary data."""

    def test_round_trip(self):
        cursor = self.cursor
        # load the table from the textual representation first
        cursor.copy_from(self.data_text, 'copytest', format='text')
        self.check_table()
        self.check_rowcount()
        # export everything again, this time in binary format
        exported = cursor.copy_to(None, 'copytest', format='binary')
        self.assertIsInstance(exported, Iterable)
        blob = b''.join(exported)
        self.assertTrue(blob.startswith(b'PGCOPY\n\377\r\n\0'))
        self.check_rowcount()
        # empty the table and re-import the binary dump
        self.truncate_table()
        cursor.copy_from(blob, 'copytest', format='binary')
        self.check_table()
        self.check_rowcount()
[694]598
599
if __name__ == '__main__':
    # run all test cases of this module when executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.