source: trunk/module/tests/test_dbapi20_copy.py @ 694

Last change on this file since 694 was 694, checked in by cito, 4 years ago

Add tests for copy_from() and copy_to()

  • Property svn:keywords set to Id
File size: 17.2 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the modern PyGreSQL interface.
5
6Sub-tests for the copy methods.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
try:
    import unittest2 as unittest  # for Python < 2.7
except ImportError:
    import unittest

try:
    # Iterable has lived in collections.abc since Python 3.3; the alias in
    # collections itself was removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable

import pgdb  # the module under test

# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must be allowed to create tables in the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError, SystemError):
    # A relative import outside a package raises ValueError on Python 2,
    # SystemError on Python 3.3-3.5 and ImportError on later versions.
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str
42
43
class InputStream:
    """Minimal readable file-like object used as a source for copy_from().

    The payload is held as UTF-8 encoded bytes.  Every size passed to
    read() is recorded in ``sizes`` so tests can check how the data was
    requested.
    """

    def __init__(self, data):
        # Text is stored as its UTF-8 encoding; a false value becomes b''.
        if isinstance(data, unicode):
            data = data.encode('utf-8')
        self.data = data or b''
        self.sizes = []

    def __str__(self):
        # Native str of the unread payload (decoded text on Python 3).
        data = self.data
        if str is unicode:
            data = data.decode('utf-8')
        return data

    def __len__(self):
        """Return the number of bytes not yet read."""
        return len(self.data)

    def read(self, size=None):
        """Consume and return up to ``size`` bytes.

        As with regular file objects, a size of None or any negative
        number reads everything that is left.  The requested size is
        recorded unchanged in ``sizes``.
        """
        if size is None or size < 0:
            output, rest = self.data, b''
        else:
            output, rest = self.data[:size], self.data[size:]
        self.data = rest
        self.sizes.append(size)
        return output
69
70
class OutputStream:
    """Writable file-like sink collecting copy_to() output as bytes.

    Text passed to write() is UTF-8 encoded before being appended to
    ``data``; the byte length of every chunk is logged in ``sizes``.
    """

    def __init__(self):
        self.sizes = []
        self.data = b''

    def __len__(self):
        """Return the total number of bytes written so far."""
        return len(self.data)

    def __str__(self):
        # Native str: the raw bytes on Python 2, decoded text on Python 3.
        if str is not unicode:
            return self.data
        return self.data.decode('utf-8')

    def write(self, data):
        """Append ``data`` (text or bytes) and record its encoded length."""
        chunk = data.encode('utf-8') if isinstance(data, unicode) else data
        self.data += chunk
        self.sizes.append(len(chunk))
91
92
class TestStreams(unittest.TestCase):
    """Sanity checks for the InputStream and OutputStream helpers."""

    def test_input(self):
        stream = InputStream('Hello, Wörld!')
        self.assertIsInstance(stream.data, bytes)
        self.assertEqual(stream.data, b'Hello, W\xc3\xb6rld!')
        text = str(stream)
        self.assertIsInstance(text, str)
        self.assertEqual(text, 'Hello, Wörld!')
        self.assertEqual(len(stream), 14)
        # (requested size, expected chunk); None means "call read() bare"
        expected_reads = [
            (3, b'Hel'), (2, b'lo'), (1, b','), (1, b' '),
            (None, b'W\xc3\xb6rld!'), (None, b'')]
        for size, chunk in expected_reads:
            if size is None:
                self.assertEqual(stream.read(), chunk)
            else:
                self.assertEqual(stream.read(size), chunk)
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [3, 2, 1, 1, None, None])

    def test_output(self):
        stream = OutputStream()
        self.assertEqual(len(stream), 0)
        pieces = ['Hel', 'lo', ',', ' ', 'Wörld!']
        for piece in pieces:
            stream.write(piece)
        self.assertIsInstance(stream.data, bytes)
        self.assertEqual(stream.data, b'Hello, W\xc3\xb6rld!')
        text = str(stream)
        self.assertIsInstance(text, str)
        self.assertEqual(text, 'Hello, Wörld!')
        self.assertEqual(len(stream), 14)
        self.assertEqual(stream.sizes, [3, 2, 1, 1, 7])
122
123
class TestCopy(unittest.TestCase):
    """Common fixtures and helpers for the copy_from/copy_to tests.

    The class fixture recreates the ``copytest`` table before the tests of
    each subclass run and drops it afterwards.  Every test gets its own
    connection and cursor; uncommitted changes are rolled back in tearDown.
    """

    @staticmethod
    def connect():
        """Open a new connection to the configured test database."""
        return pgdb.connect(database=dbname,
            host='%s:%d' % (dbhost or '', dbport or -1))

    @classmethod
    def _execute_committed(cls, *statements):
        """Run ``statements`` on a new connection and commit them.

        The cursor and the connection are closed even if a statement fails.
        """
        con = cls.connect()
        try:
            cur = con.cursor()
            try:
                for statement in statements:
                    cur.execute(statement)
            finally:
                cur.close()
            con.commit()
        finally:
            con.close()

    @classmethod
    def setUpClass(cls):
        cls._execute_committed(
            "set client_min_messages=warning",
            "drop table if exists copytest cascade",
            "create table copytest ("
            "id smallint primary key, name varchar(64))")

    @classmethod
    def tearDownClass(cls):
        cls._execute_committed(
            "set client_min_messages=warning",
            "drop table if exists copytest cascade")

    def setUp(self):
        self.con = self.connect()
        self.cursor = self.con.cursor()
        self.cursor.execute("set client_encoding=utf8")

    def tearDown(self):
        # Best-effort cleanup: the cursor or connection may already be
        # broken or closed, so errors here are deliberately ignored.
        try:
            self.cursor.close()
        except Exception:
            pass
        try:
            self.con.rollback()
        except Exception:
            pass
        try:
            self.con.close()
        except Exception:
            pass

    # sample (id, name) rows shared by the subclasses
    data = [(1935, 'Luciano Pavarotti'),
            (1941, 'Plácido Domingo'),
            (1946, 'José Carreras')]

    @property
    def data_text(self):
        """The sample rows as tab-separated, newline-terminated lines."""
        return ''.join('%d\t%s\n' % row for row in self.data)

    @property
    def data_csv(self):
        """The sample rows as comma-separated, newline-terminated lines."""
        return ''.join('%d,%s\n' % row for row in self.data)

    def truncate_table(self):
        """Remove all rows of copytest on the test cursor (not committed)."""
        self.cursor.execute("truncate table copytest")

    @property
    def table_data(self):
        """All rows of copytest as fetched through the test cursor."""
        self.cursor.execute("select * from copytest")
        return self.cursor.fetchall()

    def check_table(self):
        """Assert that copytest contains exactly the sample rows."""
        self.assertEqual(self.table_data, self.data)
193
194
class TestCopyFrom(TestCopy):
    """Test the copy_from method."""

    def tearDown(self):
        # Discard the test's own transaction and connection first.
        super(TestCopyFrom, self).tearDown()
        # Then empty the table on a fresh connection.  The truncate must be
        # committed: the base tearDown rolls back before closing and would
        # otherwise silently undo it.
        self.setUp()
        try:
            self.truncate_table()
            self.con.commit()
        finally:
            super(TestCopyFrom, self).tearDown()

    def copy_from(self, stream, **options):
        """Copy ``stream`` into the copytest table using the test cursor."""
        return self.cursor.copy_from(stream, 'copytest', **options)

    @property
    def data_file(self):
        """A fresh InputStream holding the sample rows as text."""
        return InputStream(self.data_text)

    def test_bad_params(self):
        call = self.cursor.copy_from
        self.assertIs(call('0\t', 'copytest'), self.cursor)
        call('1\t', 'copytest',
             format='text', sep='\t', null='', columns=['id', 'name'])
        self.assertRaises(TypeError, call)
        self.assertRaises(TypeError, call, '0\t')
        self.assertRaises(TypeError, call, '0\t', 42)
        self.assertRaises(TypeError, call, '0\t', ['copytest'])
        self.assertRaises(TypeError, call, '0\t', 'copytest', format=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', format='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', sep=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', sep='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', null=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', size='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', columns=42)

    def test_input_string(self):
        ret = self.copy_from('42\tHello, world!')
        self.assertIs(ret, self.cursor)
        self.assertEqual(self.table_data, [(42, 'Hello, world!')])

    def test_input_string_with_newline(self):
        self.copy_from('42\tHello, world!\n')
        self.assertEqual(self.table_data, [(42, 'Hello, world!')])

    def test_input_string_multiple_rows(self):
        ret = self.copy_from(self.data_text)
        self.assertIs(ret, self.cursor)
        self.check_table()

    if str is unicode:

        def test_input_bytes(self):
            self.copy_from(b'42\tHello, world!')
            self.assertEqual(self.table_data, [(42, 'Hello, world!')])
            self.truncate_table()
            self.copy_from(self.data_text.encode('utf-8'))
            self.check_table()

    if str is not unicode:

        def test_input_unicode(self):
            self.copy_from(u'43\tWürstel, Käse!')
            self.assertEqual(self.table_data, [(43, 'Würstel, Käse!')])
            self.truncate_table()
            self.copy_from(self.data_text.decode('utf-8'))
            self.check_table()

    def test_input_iterable(self):
        self.copy_from(self.data_text.splitlines())
        self.check_table()

    def test_input_iterable_with_newlines(self):
        self.copy_from('%s\n' % row for row in self.data_text.splitlines())
        self.check_table()

    def test_sep(self):
        stream = ('%d-%s' % row for row in self.data)
        self.copy_from(stream, sep='-')
        self.check_table()

    def test_null(self):
        self.copy_from('0\t\\N')
        self.assertEqual(self.table_data, [(0, None)])
        self.assertIsNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('1\tNix')
        self.assertEqual(self.table_data, [(1, 'Nix')])
        self.assertIsNotNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('2\tNix', null='Nix')
        self.assertEqual(self.table_data, [(2, None)])
        self.assertIsNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('3\t')
        self.assertEqual(self.table_data, [(3, '')])
        self.assertIsNotNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('4\t', null='')
        self.assertEqual(self.table_data, [(4, None)])
        self.assertIsNone(self.table_data[0][1])

    def test_columns(self):
        self.copy_from('1', columns='id')
        self.copy_from('2', columns=['id'])
        self.copy_from('3\tThree')
        self.copy_from('4\tFour', columns='id, name')
        self.copy_from('5\tFive', columns=['id', 'name'])
        self.assertEqual(self.table_data, [
            (1, None), (2, None), (3, 'Three'), (4, 'Four'), (5, 'Five')])
        self.assertRaises(pgdb.ProgrammingError, self.copy_from,
            '6\t42', columns=['id', 'age'])

    def test_csv(self):
        self.copy_from(self.data_csv, format='csv')
        self.check_table()

    def test_csv_with_sep(self):
        stream = ('%d;"%s"\n' % row for row in self.data)
        self.copy_from(stream, format='csv', sep=';')
        self.check_table()

    def test_binary(self):
        self.assertRaises(IOError, self.copy_from,
            b'NOPGCOPY\n', format='binary')

    def test_binary_with_sep(self):
        self.assertRaises(ValueError, self.copy_from,
            '', format='binary', sep='\t')

    def test_binary_with_unicode(self):
        self.assertRaises(ValueError, self.copy_from, u'', format='binary')

    def test_query(self):
        self.assertRaises(ValueError, self.cursor.copy_from, '', "select null")

    def test_file(self):
        stream = self.data_file
        ret = self.copy_from(stream)
        self.assertIs(ret, self.cursor)
        self.check_table()
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [8192])

    def test_size_positive(self):
        stream = self.data_file
        size = 7
        # ceiling division: number of reads needed to drain the stream
        num_chunks = (len(stream) + size - 1) // size
        self.copy_from(stream, size=size)
        self.check_table()
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [size] * num_chunks)

    def test_size_negative(self):
        stream = self.data_file
        self.copy_from(stream, size=-1)
        self.check_table()
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [None])
351
352
class TestCopyTo(TestCopy):
    """Test the copy_to method."""

    @classmethod
    def setUpClass(cls):
        super(TestCopyTo, cls).setUpClass()
        # Fill the freshly created table with the sample rows for all tests
        # of this class; connection and cursor are closed even on failure.
        con = cls.connect()
        try:
            cur = con.cursor()
            try:
                # executemany() is the DB-API 2 way to run one statement
                # per parameter tuple.
                cur.executemany("insert into copytest values (%d, %s)",
                    cls.data)
            finally:
                cur.close()
            con.commit()
        finally:
            con.close()

    def copy_to(self, stream=None, **options):
        """Copy the copytest table to ``stream`` using the test cursor.

        With the default ``stream=None`` the rows are returned as an
        iterable instead.
        """
        return self.cursor.copy_to(stream, 'copytest', **options)

    @property
    def data_file(self):
        """A fresh, empty OutputStream to copy into."""
        return OutputStream()

    def test_bad_params(self):
        call = self.cursor.copy_to
        call(None, 'copytest')
        call(None, 'copytest',
             format='text', sep='\t', null='', columns=['id', 'name'])
        self.assertRaises(TypeError, call)
        self.assertRaises(TypeError, call, None)
        self.assertRaises(TypeError, call, None, 42)
        self.assertRaises(TypeError, call, None, ['copytest'])
        self.assertRaises(TypeError, call, 'bad', 'copytest')
        self.assertRaises(TypeError, call, None, 'copytest', format=42)
        self.assertRaises(ValueError, call, None, 'copytest', format='bad')
        self.assertRaises(TypeError, call, None, 'copytest', sep=42)
        self.assertRaises(ValueError, call, None, 'copytest', sep='bad')
        self.assertRaises(TypeError, call, None, 'copytest', null=42)
        self.assertRaises(TypeError, call, None, 'copytest', decode='bad')
        self.assertRaises(TypeError, call, None, 'copytest', columns=42)

    def test_generator(self):
        ret = self.copy_to()
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 3)
        rows = ''.join(rows)
        self.assertIsInstance(rows, str)
        self.assertEqual(rows, self.data_text)

    if str is unicode:

        def test_generator_bytes(self):
            ret = self.copy_to(decode=False)
            self.assertIsInstance(ret, Iterable)
            rows = list(ret)
            self.assertEqual(len(rows), 3)
            rows = b''.join(rows)
            self.assertIsInstance(rows, bytes)
            self.assertEqual(rows, self.data_text.encode('utf-8'))

    if str is not unicode:

        def test_generator_unicode(self):
            ret = self.copy_to(decode=True)
            self.assertIsInstance(ret, Iterable)
            rows = list(ret)
            self.assertEqual(len(rows), 3)
            rows = ''.join(rows)
            self.assertIsInstance(rows, unicode)
            self.assertEqual(rows, self.data_text.decode('utf-8'))

    def test_decode(self):
        ret_raw = b''.join(self.copy_to(decode=False))
        ret_decoded = ''.join(self.copy_to(decode=True))
        self.assertIsInstance(ret_raw, bytes)
        self.assertIsInstance(ret_decoded, unicode)
        self.assertEqual(ret_decoded, ret_raw.decode('utf-8'))

    def test_sep(self):
        ret = list(self.copy_to(sep='-'))
        self.assertEqual(ret, ['%d-%s\n' % row for row in self.data])

    def test_null(self):
        data = ['%d\t%s\n' % row for row in self.data]
        self.cursor.execute('insert into copytest values(4, null)')
        try:
            ret = list(self.copy_to())
            self.assertEqual(ret, data + ['4\t\\N\n'])
            ret = list(self.copy_to(null='Nix'))
            self.assertEqual(ret, data + ['4\tNix\n'])
            ret = list(self.copy_to(null=''))
            self.assertEqual(ret, data + ['4\t\n'])
        finally:
            self.cursor.execute('delete from copytest where id=4')

    def test_columns(self):
        data_id = ''.join('%d\n' % row[0] for row in self.data)
        data_name = ''.join('%s\n' % row[1] for row in self.data)
        ret = ''.join(self.copy_to(columns='id'))
        self.assertEqual(ret, data_id)
        ret = ''.join(self.copy_to(columns=['id']))
        self.assertEqual(ret, data_id)
        ret = ''.join(self.copy_to(columns='name'))
        self.assertEqual(ret, data_name)
        ret = ''.join(self.copy_to(columns=['name']))
        self.assertEqual(ret, data_name)
        ret = ''.join(self.copy_to(columns='id, name'))
        self.assertEqual(ret, self.data_text)
        ret = ''.join(self.copy_to(columns=['id', 'name']))
        self.assertEqual(ret, self.data_text)
        self.assertRaises(pgdb.ProgrammingError, self.copy_to,
            columns=['id', 'age'])

    def test_csv(self):
        ret = self.copy_to(format='csv')
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 3)
        rows = ''.join(rows)
        self.assertIsInstance(rows, str)
        self.assertEqual(rows, self.data_csv)

    def test_csv_with_sep(self):
        rows = ''.join(self.copy_to(format='csv', sep=';'))
        self.assertEqual(rows, self.data_csv.replace(',', ';'))

    def test_binary(self):
        ret = self.copy_to(format='binary')
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        # Require at least one chunk; a loop over an empty result would
        # let this test pass without checking anything.
        self.assertTrue(rows)
        self.assertTrue(rows[0].startswith(b'PGCOPY\n\377\r\n\0'))

    def test_binary_with_sep(self):
        self.assertRaises(ValueError, self.copy_to, format='binary', sep='\t')

    def test_binary_with_unicode(self):
        self.assertRaises(ValueError, self.copy_to,
            format='binary', decode=True)

    def test_query(self):
        ret = self.cursor.copy_to(None,
            "select name||'!' from copytest where id=1941")
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 1)
        self.assertIsInstance(rows[0], str)
        self.assertEqual(rows[0], '%s!\n' % self.data[1][1])

    def test_file(self):
        stream = self.data_file
        ret = self.copy_to(stream)
        self.assertIs(ret, self.cursor)
        self.assertEqual(str(stream), self.data_text)
        data = self.data_text
        if str is unicode:
            data = data.encode('utf-8')
        # one write per row, each including its trailing newline
        sizes = [len(row) + 1 for row in data.splitlines()]
        self.assertEqual(stream.sizes, sizes)
510
511
class TestBinary(TestCopy):
    """Test the copy_from and copy_to methods with binary data."""

    def test_round_trip(self):
        cursor = self.cursor
        # load the sample rows from tab-separated text first
        cursor.copy_from(self.data_text, 'copytest', format='text')
        self.check_table()
        # read them back out in binary COPY format
        chunks = cursor.copy_to(None, 'copytest', format='binary')
        self.assertIsInstance(chunks, Iterable)
        dump = b''.join(chunks)
        self.assertTrue(dump.startswith(b'PGCOPY\n\377\r\n\0'))
        # empty the table and reload it from the binary dump
        self.truncate_table()
        cursor.copy_from(dump, 'copytest', format='binary')
        self.check_table()
528
529
if __name__ == '__main__':
    # Run all test cases defined in this module when executed as a script.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.