source: trunk/tests/test_dbapi20_copy.py @ 788

Last change on this file since 788 was 772, checked in by cito, 4 years ago

Improve test coverage for the pgdb module

Includes a simple patch that allows storing Python lists or tuple values
in PostgreSQL array fields (they are not yet converted when read, though).

Also re-activated the shortcut methods on the connection,
since they can sometimes be useful.

Test coverage is now around 95%; the remaining lines are due to support for
old Python versions or obscure database errors that can't easily be provoked.

  • Property svn:keywords set to Id
File size: 19.1 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the modern PyGreSQL interface.
5
6Sub-tests for the copy methods.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
# The ABC aliases in ``collections`` were removed in Python 3.10, so the
# modern location is tried first and the old one is kept for Python 2.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
19
20import pgdb  # the module under test
21
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'  # passed as the database name to pgdb.connect()
dbhost = None  # a false value becomes an empty host part in TestCopy.connect()
dbport = 5432  # joined with dbhost as 'host:port' in TestCopy.connect()

# The star imports allow a local settings module to override the names
# defined above.  The package-relative form is tried first; when that fails
# with ImportError or ValueError, a top-level module of the same name is
# tried.  If neither can be imported, the defaults stay in effect.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Make the name ``unicode`` available on every Python version, so that the
# rest of the module can test ``str is unicode`` to detect Python 3.
try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str
41
42
class InputStream:
    """Minimal readable byte stream that records every read request.

    Text passed to the constructor is encoded as UTF-8, so ``data`` always
    holds bytes.  ``sizes`` collects the ``size`` argument of each ``read``
    call in order, so a test can inspect how a consumer chunked its reads.
    """

    def __init__(self, data):
        if isinstance(data, unicode):
            data = data.encode('utf-8')
        self.data = data or b''  # a false value such as None means "empty"
        self.sizes = []

    def __str__(self):
        """Return the data that has not been read yet as a native string."""
        data = self.data
        if str is unicode:  # Python >= 3.0
            data = data.decode('utf-8')
        return data

    def __len__(self):
        """Return the number of bytes that have not been read yet."""
        return len(self.data)

    def read(self, size=None):
        """Return up to ``size`` bytes and remove them from the buffer.

        As with regular file objects, a ``size`` of None or a negative
        number reads everything that is left.  The requested size is
        recorded in ``sizes`` exactly as it was passed in.
        """
        if size is None or size < 0:
            # A negative size must not be used as a slice index, since that
            # would silently leave the tail of the data in the buffer.
            output, data = self.data, b''
        else:
            output, data = self.data[:size], self.data[size:]
        self.data = data
        self.sizes.append(size)
        return output
68
69
class OutputStream:
    """Minimal writable byte stream that records the size of every write.

    Everything written is accumulated as bytes in ``data`` (text is encoded
    as UTF-8 first), and the byte length of each chunk is appended to
    ``sizes`` in the order the chunks arrived.
    """

    def __init__(self):
        self.sizes = []
        self.data = b''

    def __len__(self):
        return len(self.data)

    def __str__(self):
        if str is not unicode:  # Python < 3.0
            return self.data
        return self.data.decode('utf-8')

    def write(self, data):
        chunk = data.encode('utf-8') if isinstance(data, unicode) else data
        self.data = self.data + chunk
        self.sizes.append(len(chunk))
90
91
class TestStreams(unittest.TestCase):
    """Check the InputStream and OutputStream helpers themselves."""

    def test_input(self):
        # the text is stored as UTF-8 encoded bytes
        source = InputStream('Hello, Wörld!')
        self.assertIsInstance(source.data, bytes)
        self.assertEqual(source.data, b'Hello, W\xc3\xb6rld!')
        as_text = str(source)
        self.assertIsInstance(as_text, str)
        self.assertEqual(as_text, 'Hello, Wörld!')
        self.assertEqual(len(source), 14)
        # sized reads hand out the data piece by piece
        for size, expected in ((3, b'Hel'), (2, b'lo'), (1, b','), (1, b' ')):
            self.assertEqual(source.read(size), expected)
        # an unsized read returns the rest, after that the stream is empty
        self.assertEqual(source.read(), b'W\xc3\xb6rld!')
        self.assertEqual(source.read(), b'')
        self.assertEqual(len(source), 0)
        self.assertEqual(source.sizes, [3, 2, 1, 1, None, None])

    def test_output(self):
        sink = OutputStream()
        self.assertEqual(len(sink), 0)
        pieces = ('Hel', 'lo', ',', ' ', 'Wörld!')
        for piece in pieces:
            sink.write(piece)
        # the written text is accumulated as UTF-8 encoded bytes
        self.assertIsInstance(sink.data, bytes)
        self.assertEqual(sink.data, b'Hello, W\xc3\xb6rld!')
        as_text = str(sink)
        self.assertIsInstance(as_text, str)
        self.assertEqual(as_text, 'Hello, Wörld!')
        self.assertEqual(len(sink), 14)
        # sizes are counted in bytes, so the 'ö' adds one to the last chunk
        self.assertEqual(sink.sizes, [3, 2, 1, 1, 7])
121
122
class TestCopy(unittest.TestCase):
    """Shared fixture for the copy tests.

    The ``copytest`` table is created once per test class and dropped again
    afterwards.  Every test gets its own connection and cursor, and a few
    helpers compare the table contents and the cursor's row count with the
    sample ``data``.
    """

    cls_set_up = False  # switched to True at the end of setUpClass

    @staticmethod
    def connect():
        """Open a new connection using the module-level settings."""
        host = '%s:%d' % (dbhost or '', dbport or -1)
        return pgdb.connect(database=dbname, host=host)

    @classmethod
    def setUpClass(cls):
        """(Re-)create the empty copytest table."""
        create_table = ("create table copytest ("
            "id smallint primary key, name varchar(64))")
        connection = cls.connect()
        cursor = connection.cursor()
        for statement in ("set client_min_messages=warning",
                "drop table if exists copytest cascade", create_table):
            cursor.execute(statement)
        cursor.close()
        connection.commit()
        connection.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the copytest table."""
        connection = cls.connect()
        cursor = connection.cursor()
        for statement in ("set client_min_messages=warning",
                "drop table if exists copytest cascade"):
            cursor.execute(statement)
        connection.commit()
        connection.close()

    def setUp(self):
        """Open the connection and cursor used by a single test."""
        self.assertTrue(self.cls_set_up)
        self.con = connection = self.connect()
        self.cursor = cursor = connection.cursor()
        cursor.execute("set client_encoding=utf8")

    def tearDown(self):
        """Close cursor and connection, ignoring any errors on the way."""
        steps = (
            lambda: self.cursor.close(),
            lambda: self.con.rollback(),
            lambda: self.con.close())
        for step in steps:
            # each step is attempted even if an earlier one has failed
            try:
                step()
            except Exception:
                pass

    data = [(1935, 'Luciano Pavarotti'),
            (1941, 'Plácido Domingo'),
            (1946, 'José Carreras')]

    @property
    def data_text(self):
        """The sample data as tab separated lines."""
        lines = []
        for row in self.data:
            lines.append('%d\t%s\n' % row)
        return ''.join(lines)

    @property
    def data_csv(self):
        """The sample data as comma separated lines."""
        lines = []
        for row in self.data:
            lines.append('%d,%s\n' % row)
        return ''.join(lines)

    def truncate_table(self):
        """Remove all rows from the copytest table."""
        self.cursor.execute("truncate table copytest")

    @property
    def table_data(self):
        """All rows currently found in the copytest table."""
        cursor = self.cursor
        cursor.execute("select * from copytest")
        return cursor.fetchall()

    def check_table(self):
        """Assert that the table holds exactly the sample data."""
        rows = self.table_data
        self.assertEqual(rows, self.data)

    def check_rowcount(self, number=len(data)):
        """Assert that the cursor reports the given number of rows."""
        count = self.cursor.rowcount
        self.assertEqual(count, number)
199
200
class TestCopyFrom(TestCopy):
    """Test the copy_from method."""

    def tearDown(self):
        """Release the test's connection, then empty the table again.

        After the regular clean-up a fresh connection is opened just to
        truncate the copytest table, and is then torn down as well.
        """
        super(TestCopyFrom, self).tearDown()
        self.setUp()
        self.truncate_table()
        super(TestCopyFrom, self).tearDown()

    def copy_from(self, stream, **options):
        """Call cursor.copy_from with 'copytest' as the target table."""
        return self.cursor.copy_from(stream, 'copytest', **options)

    @property
    def data_file(self):
        """A fresh InputStream holding the sample data in text format."""
        return InputStream(self.data_text)

    def test_bad_params(self):
        call = self.cursor.copy_from
        # well-formed calls succeed and hand back the cursor
        self.assertIs(call('0\t', 'copytest'), self.cursor)
        call('1\t', 'copytest',
             format='text', sep='\t', null='', columns=['id', 'name'])
        # missing or unsuitable stream and table arguments
        self.assertRaises(TypeError, call)
        self.assertRaises(TypeError, call, None)
        self.assertRaises(TypeError, call, None, None)
        self.assertRaises(TypeError, call, '0\t')
        self.assertRaises(TypeError, call, '0\t', None)
        self.assertRaises(TypeError, call, '0\t', 42)
        self.assertRaises(TypeError, call, '0\t', ['copytest'])
        # unsuitable option values
        self.assertRaises(TypeError, call, '0\t', 'copytest', format=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', format='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', sep=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', sep='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', null=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', size='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', columns=42)
        # a separator cannot be combined with the binary format
        self.assertRaises(ValueError, call, b'', 'copytest',
            format='binary', sep=',')

    def test_input_string(self):
        ret = self.copy_from('42\tHello, world!')
        self.assertIs(ret, self.cursor)
        self.assertEqual(self.table_data, [(42, 'Hello, world!')])
        self.check_rowcount(1)

    def test_input_string_with_newline(self):
        self.copy_from('42\tHello, world!\n')
        self.assertEqual(self.table_data, [(42, 'Hello, world!')])
        self.check_rowcount(1)

    def test_input_string_multiple_rows(self):
        ret = self.copy_from(self.data_text)
        self.assertIs(ret, self.cursor)
        self.check_table()
        self.check_rowcount()

    if str is unicode:  # Python >= 3.0

        def test_input_bytes(self):
            self.copy_from(b'42\tHello, world!')
            self.assertEqual(self.table_data, [(42, 'Hello, world!')])
            self.truncate_table()
            self.copy_from(self.data_text.encode('utf-8'))
            self.check_table()

    else:  # Python < 3.0

        def test_input_unicode(self):
            self.copy_from(u'43\tWürstel, Käse!')
            self.assertEqual(self.table_data, [(43, 'Würstel, Käse!')])
            self.truncate_table()
            self.copy_from(self.data_text.decode('utf-8'))
            self.check_table()

    def test_input_iterable(self):
        # rows may be passed without trailing newlines
        self.copy_from(self.data_text.splitlines())
        self.check_table()
        self.check_rowcount()

    def test_input_iterable_invalid(self):
        self.assertRaises(IOError, self.copy_from, [None])

    def test_input_iterable_with_newlines(self):
        self.copy_from('%s\n' % row for row in self.data_text.splitlines())
        self.check_table()

    if str is unicode:  # Python >= 3.0

        def test_input_iterable_bytes(self):
            self.copy_from(row.encode('utf-8')
                for row in self.data_text.splitlines())
            self.check_table()

    def test_sep(self):
        stream = ('%d-%s' % row for row in self.data)
        self.copy_from(stream, sep='-')
        self.check_table()

    def test_null(self):
        # without the null option, \N is read as NULL
        self.copy_from('0\t\\N')
        self.assertEqual(self.table_data, [(0, None)])
        self.assertIsNone(self.table_data[0][1])
        self.truncate_table()
        # an ordinary word is stored as is ...
        self.copy_from('1\tNix')
        self.assertEqual(self.table_data, [(1, 'Nix')])
        self.assertIsNotNone(self.table_data[0][1])
        self.truncate_table()
        # ... unless it has been declared as the null marker
        self.copy_from('2\tNix', null='Nix')
        self.assertEqual(self.table_data, [(2, None)])
        self.assertIsNone(self.table_data[0][1])
        self.truncate_table()
        # the same holds for the empty string
        self.copy_from('3\t')
        self.assertEqual(self.table_data, [(3, '')])
        self.assertIsNotNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('4\t', null='')
        self.assertEqual(self.table_data, [(4, None)])
        self.assertIsNone(self.table_data[0][1])

    def test_columns(self):
        # columns can be given as a string or as a list of names
        self.copy_from('1', columns='id')
        self.copy_from('2', columns=['id'])
        self.copy_from('3\tThree')
        self.copy_from('4\tFour', columns='id, name')
        self.copy_from('5\tFive', columns=['id', 'name'])
        self.assertEqual(self.table_data, [
            (1, None), (2, None), (3, 'Three'), (4, 'Four'), (5, 'Five')])
        self.check_rowcount(5)
        # a column that does not exist is an error and resets the rowcount
        self.assertRaises(pgdb.ProgrammingError, self.copy_from,
            '6\t42', columns=['id', 'age'])
        self.check_rowcount(-1)

    def test_csv(self):
        self.copy_from(self.data_csv, format='csv')
        self.check_table()

    def test_csv_with_sep(self):
        stream = ('%d;"%s"\n' % row for row in self.data)
        self.copy_from(stream, format='csv', sep=';')
        self.check_table()
        self.check_rowcount()

    def test_binary(self):
        # data without a proper binary signature is refused
        self.assertRaises(IOError, self.copy_from,
            b'NOPGCOPY\n', format='binary')
        self.check_rowcount(-1)

    def test_binary_with_sep(self):
        self.assertRaises(ValueError, self.copy_from,
            '', format='binary', sep='\t')

    def test_binary_with_unicode(self):
        self.assertRaises(ValueError, self.copy_from, u'', format='binary')

    def test_query(self):
        # copy_from needs a table name, a query is not accepted
        self.assertRaises(ValueError, self.cursor.copy_from, '', "select null")

    def test_file(self):
        stream = self.data_file
        ret = self.copy_from(stream)
        self.assertIs(ret, self.cursor)
        self.check_table()
        # the stream has been emptied with one read request of 8192 bytes
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [8192])
        self.check_rowcount()

    def test_size_positive(self):
        stream = self.data_file
        size = 7
        # number of reads needed for the whole stream, rounded up
        num_chunks = (len(stream) + size - 1) // size
        self.copy_from(stream, size=size)
        self.check_table()
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [size] * num_chunks)
        self.check_rowcount()

    def test_size_negative(self):
        stream = self.data_file
        self.copy_from(stream, size=-1)
        self.check_table()
        # the stream has been emptied with a single unsized read
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [None])
        self.check_rowcount()

    def test_size_invalid(self):
        self.assertRaises(TypeError,
            self.copy_from, self.data_file, size='invalid')
387
388
class TestCopyTo(TestCopy):
    """Test the copy_to method."""

    @classmethod
    def setUpClass(cls):
        """Create the table and fill it with the committed sample data."""
        super(TestCopyTo, cls).setUpClass()
        con = cls.connect()
        cur = con.cursor()
        # One parameter tuple per row: this is what executemany() is for.
        # Passing such a list to execute() is deprecated by the DB-API.
        cur.executemany("insert into copytest values (%d, %s)", cls.data)
        cur.close()
        con.commit()
        con.close()

    def copy_to(self, stream=None, **options):
        """Call cursor.copy_to with 'copytest' as the source table."""
        return self.cursor.copy_to(stream, 'copytest', **options)

    @property
    def data_file(self):
        """A fresh, empty OutputStream."""
        return OutputStream()

    def test_bad_params(self):
        call = self.cursor.copy_to
        # well-formed calls succeed
        call(None, 'copytest')
        call(None, 'copytest',
             format='text', sep='\t', null='', columns=['id', 'name'])
        # missing or unsuitable stream and table arguments
        self.assertRaises(TypeError, call)
        self.assertRaises(TypeError, call, None)
        self.assertRaises(TypeError, call, None, 42)
        self.assertRaises(TypeError, call, None, ['copytest'])
        self.assertRaises(TypeError, call, 'bad', 'copytest')
        # unsuitable option values
        self.assertRaises(TypeError, call, None, 'copytest', format=42)
        self.assertRaises(ValueError, call, None, 'copytest', format='bad')
        self.assertRaises(TypeError, call, None, 'copytest', sep=42)
        self.assertRaises(ValueError, call, None, 'copytest', sep='bad')
        self.assertRaises(TypeError, call, None, 'copytest', null=42)
        self.assertRaises(TypeError, call, None, 'copytest', decode='bad')
        self.assertRaises(TypeError, call, None, 'copytest', columns=42)

    def test_generator(self):
        # without a stream, copy_to returns an iterable with one item per row
        ret = self.copy_to()
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 3)
        rows = ''.join(rows)
        self.assertIsInstance(rows, str)
        self.assertEqual(rows, self.data_text)
        self.check_rowcount()

    if str is unicode:  # Python >= 3.0

        def test_generator_bytes(self):
            ret = self.copy_to(decode=False)
            self.assertIsInstance(ret, Iterable)
            rows = list(ret)
            self.assertEqual(len(rows), 3)
            rows = b''.join(rows)
            self.assertIsInstance(rows, bytes)
            self.assertEqual(rows, self.data_text.encode('utf-8'))

    else:  # Python < 3.0

        def test_generator_unicode(self):
            ret = self.copy_to(decode=True)
            self.assertIsInstance(ret, Iterable)
            rows = list(ret)
            self.assertEqual(len(rows), 3)
            rows = ''.join(rows)
            self.assertIsInstance(rows, unicode)
            self.assertEqual(rows, self.data_text.decode('utf-8'))

    def test_rowcount_increment(self):
        # the rowcount grows with every row taken from the iterable
        ret = self.copy_to()
        self.assertIsInstance(ret, Iterable)
        for n, row in enumerate(ret):
            self.check_rowcount(n + 1)

    def test_decode(self):
        ret_raw = b''.join(self.copy_to(decode=False))
        ret_decoded = ''.join(self.copy_to(decode=True))
        self.assertIsInstance(ret_raw, bytes)
        self.assertIsInstance(ret_decoded, unicode)
        self.assertEqual(ret_decoded, ret_raw.decode('utf-8'))
        self.check_rowcount()

    def test_sep(self):
        ret = list(self.copy_to(sep='-'))
        self.assertEqual(ret, ['%d-%s\n' % row for row in self.data])

    def test_null(self):
        data = ['%d\t%s\n' % row for row in self.data]
        # temporarily add a row with a NULL name
        self.cursor.execute('insert into copytest values(4, null)')
        try:
            # NULL is written as \N unless another marker is requested
            ret = list(self.copy_to())
            self.assertEqual(ret, data + ['4\t\\N\n'])
            ret = list(self.copy_to(null='Nix'))
            self.assertEqual(ret, data + ['4\tNix\n'])
            ret = list(self.copy_to(null=''))
            self.assertEqual(ret, data + ['4\t\n'])
        finally:
            self.cursor.execute('delete from copytest where id=4')

    def test_columns(self):
        data_id = ''.join('%d\n' % row[0] for row in self.data)
        data_name = ''.join('%s\n' % row[1] for row in self.data)
        # columns can be given as a string or as a list of names
        ret = ''.join(self.copy_to(columns='id'))
        self.assertEqual(ret, data_id)
        ret = ''.join(self.copy_to(columns=['id']))
        self.assertEqual(ret, data_id)
        ret = ''.join(self.copy_to(columns='name'))
        self.assertEqual(ret, data_name)
        ret = ''.join(self.copy_to(columns=['name']))
        self.assertEqual(ret, data_name)
        ret = ''.join(self.copy_to(columns='id, name'))
        self.assertEqual(ret, self.data_text)
        ret = ''.join(self.copy_to(columns=['id', 'name']))
        self.assertEqual(ret, self.data_text)
        # a column that does not exist is an error
        self.assertRaises(pgdb.ProgrammingError, self.copy_to,
            columns=['id', 'age'])

    def test_csv(self):
        ret = self.copy_to(format='csv')
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 3)
        rows = ''.join(rows)
        self.assertIsInstance(rows, str)
        self.assertEqual(rows, self.data_csv)
        self.check_rowcount(3)

    def test_csv_with_sep(self):
        rows = ''.join(self.copy_to(format='csv', sep=';'))
        self.assertEqual(rows, self.data_csv.replace(',', ';'))

    def test_binary(self):
        ret = self.copy_to(format='binary')
        self.assertIsInstance(ret, Iterable)
        # only the first chunk is fetched and checked for the signature
        for row in ret:
            self.assertTrue(row.startswith(b'PGCOPY\n\377\r\n\0'))
            break
        self.check_rowcount(1)

    def test_binary_with_sep(self):
        self.assertRaises(ValueError, self.copy_to, format='binary', sep='\t')

    def test_binary_with_unicode(self):
        self.assertRaises(ValueError, self.copy_to,
            format='binary', decode=True)

    def test_query(self):
        # the columns option cannot be combined with a query
        self.assertRaises(ValueError, self.cursor.copy_to, None,
            "select name from copytest", columns='noname')
        # a query can be used in place of a table name
        ret = self.cursor.copy_to(None,
            "select name||'!' from copytest where id=1941")
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 1)
        self.assertIsInstance(rows[0], str)
        self.assertEqual(rows[0], '%s!\n' % self.data[1][1])
        self.check_rowcount(1)

    def test_file(self):
        stream = self.data_file
        # with a stream, copy_to writes into it and returns the cursor
        ret = self.copy_to(stream)
        self.assertIs(ret, self.cursor)
        self.assertEqual(str(stream), self.data_text)
        data = self.data_text
        if str is unicode:  # Python >= 3.0
            data = data.encode('utf-8')
        # one write per row; sizes are in bytes and include the newline
        sizes = [len(row) + 1 for row in data.splitlines()]
        self.assertEqual(stream.sizes, sizes)
        self.check_rowcount()
560
561
class TestBinary(TestCopy):
    """Test the copy_from and copy_to methods with binary data."""

    def test_round_trip(self):
        cursor = self.cursor
        signature = b'PGCOPY\n\377\r\n\0'
        # step 1: load the sample data in text format
        cursor.copy_from(self.data_text, 'copytest', format='text')
        self.check_table()
        self.check_rowcount()
        # step 2: dump the table in binary format
        chunks = cursor.copy_to(None, 'copytest', format='binary')
        self.assertIsInstance(chunks, Iterable)
        dump = b''.join(chunks)
        self.assertTrue(dump.startswith(signature))
        self.check_rowcount()
        self.truncate_table()
        # step 3: restore the table from the binary dump
        cursor.copy_from(dump, 'copytest', format='binary')
        self.check_table()
        self.check_rowcount()
581
582
# Run the tests of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
Note: See TracBrowser for help on using the repository browser.