source: trunk/module/tests/test_dbapi20_copy.py @ 697

Last change on this file since 697 was 697, checked in by cito, 4 years ago

Test rowcount attribute with copy methods

  • Property svn:keywords set to Id
File size: 18.1 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the modern PyGreSQL interface.
5
6Sub-tests for the copy methods.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
try:
    import unittest2 as unittest  # for Python < 2.7
except ImportError:
    import unittest

# The ABC aliases in the collections module were removed in Python 3.10,
# so prefer collections.abc and only fall back for old Python versions.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable

import pgdb  # the module under test

# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

# A relative import outside of a package raises ValueError on Python 2,
# SystemError on Python 3.3 to 3.5 and ImportError on newer versions.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError, SystemError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Make the name of the text type available on all Python versions.
try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str
42
43
class InputStream:
    """Minimal file-like object for feeding data to the copy_from method.

    Text passed in is stored as UTF-8 encoded bytes.  The size argument
    of every read() call is recorded in the ``sizes`` list so that tests
    can verify how the stream has been consumed.
    """

    def __init__(self, data):
        """Store the data as bytes (text is encoded, falsy becomes b'')."""
        if isinstance(data, unicode):
            data = data.encode('utf-8')
        self.data = data or b''
        self.sizes = []

    def __str__(self):
        """Return the data not yet read as a native string."""
        data = self.data
        if str is unicode:  # native strings are text, so decode the bytes
            data = data.decode('utf-8')
        return data

    def __len__(self):
        """Return the number of bytes that have not been read yet."""
        return len(self.data)

    def read(self, size=None):
        """Remove and return up to size bytes from the front of the data.

        As with regular file objects, everything that is left is returned
        when size is None or negative.  Slicing with a negative size would
        otherwise hold back the trailing bytes and could signal the end of
        the stream while data is still unread.  The size is recorded in
        ``sizes`` exactly as it has been passed in.
        """
        if size is None or size < 0:
            output, data = self.data, b''
        else:
            output, data = self.data[:size], self.data[size:]
        self.data = data
        self.sizes.append(size)
        return output
70
class OutputStream:
    """Minimal file-like object collecting the output of the copy_to method.

    Everything written is accumulated as bytes in ``data`` and the byte
    length of each written chunk is recorded in the ``sizes`` list.
    """

    def __init__(self):
        """Start with an empty buffer and an empty list of chunk sizes."""
        self.sizes = []
        self.data = b''

    def __len__(self):
        """Return the number of bytes written so far."""
        return len(self.data)

    def __str__(self):
        """Return everything written so far as a native string."""
        if str is not unicode:  # native strings are bytes already
            return self.data
        return self.data.decode('utf-8')

    def write(self, data):
        """Append a chunk to the buffer, encoding text as UTF-8 first."""
        chunk = data.encode('utf-8') if isinstance(data, unicode) else data
        self.data = self.data + chunk
        self.sizes.append(len(chunk))
91
92
class TestStreams(unittest.TestCase):
    """Sanity checks for the InputStream and OutputStream helper classes."""

    def test_input(self):
        # The input stream holds UTF-8 bytes, hands them out in chunks
        # and keeps a record of the requested sizes.
        source = InputStream('Hello, Wörld!')
        self.assertIsInstance(source.data, bytes)
        self.assertEqual(source.data, b'Hello, W\xc3\xb6rld!')
        as_string = str(source)
        self.assertIsInstance(as_string, str)
        self.assertEqual(as_string, 'Hello, Wörld!')
        self.assertEqual(len(source), 14)
        sized_reads = ((3, b'Hel'), (2, b'lo'), (1, b','), (1, b' '))
        for size, expected in sized_reads:
            self.assertEqual(source.read(size), expected)
        # reading without a size returns the rest, then nothing is left
        self.assertEqual(source.read(), b'W\xc3\xb6rld!')
        self.assertEqual(source.read(), b'')
        self.assertEqual(len(source), 0)
        self.assertEqual(source.sizes, [3, 2, 1, 1, None, None])

    def test_output(self):
        # The output stream collects the chunks as UTF-8 bytes and keeps
        # a record of the byte length of every chunk.
        sink = OutputStream()
        self.assertEqual(len(sink), 0)
        chunks = ('Hel', 'lo', ',', ' ', 'Wörld!')
        for chunk in chunks:
            sink.write(chunk)
        self.assertIsInstance(sink.data, bytes)
        self.assertEqual(sink.data, b'Hello, W\xc3\xb6rld!')
        as_string = str(sink)
        self.assertIsInstance(as_string, str)
        self.assertEqual(as_string, 'Hello, Wörld!')
        self.assertEqual(len(sink), 14)
        self.assertEqual(sink.sizes, [3, 2, 1, 1, 7])
122
123
class TestCopy(unittest.TestCase):
    """Common fixtures and helpers for the tests of the copy methods.

    The class creates the table "copytest" once before the tests are run
    and drops it again afterwards.  Every single test gets a fresh
    connection and cursor in ``self.con`` and ``self.cursor``.
    """

    @staticmethod
    def connect():
        """Return a new connection to the configured test database."""
        return pgdb.connect(database=dbname,
            host='%s:%d' % (dbhost or '', dbport or -1))

    @classmethod
    def setUpClass(cls):
        """Create an empty copytest table and commit it."""
        con = cls.connect()
        cur = con.cursor()
        cur.execute("set client_min_messages=warning")
        cur.execute("drop table if exists copytest cascade")
        cur.execute("create table copytest ("
            "id smallint primary key, name varchar(64))")
        cur.close()
        con.commit()
        con.close()

    @classmethod
    def tearDownClass(cls):
        """Drop the copytest table again and commit that."""
        con = cls.connect()
        cur = con.cursor()
        cur.execute("set client_min_messages=warning")
        cur.execute("drop table if exists copytest cascade")
        cur.close()  # release the cursor like setUpClass does
        con.commit()
        con.close()

    def setUp(self):
        """Open a connection and a cursor using the UTF-8 encoding."""
        self.con = self.connect()
        self.cursor = self.con.cursor()
        self.cursor.execute("set client_encoding=utf8")

    def tearDown(self):
        """Close cursor and connection, rolling back pending changes.

        This is a best effort cleanup: every step is attempted even if
        an earlier one has failed, for instance because a test already
        closed the cursor or left the connection in a broken state.
        """
        try:
            self.cursor.close()
        except Exception:
            pass
        try:
            self.con.rollback()
        except Exception:
            pass
        try:
            self.con.close()
        except Exception:
            pass

    # the rows used as test data, as (id, name) tuples
    data = [(1935, 'Luciano Pavarotti'),
            (1941, 'Plácido Domingo'),
            (1946, 'José Carreras')]

    @property
    def data_text(self):
        """The test data as tab separated lines."""
        return ''.join('%d\t%s\n' % row for row in self.data)

    @property
    def data_csv(self):
        """The test data as comma separated lines."""
        return ''.join('%d,%s\n' % row for row in self.data)

    def truncate_table(self):
        """Remove all rows from the copytest table."""
        self.cursor.execute("truncate table copytest")

    @property
    def table_data(self):
        """All rows currently found in the copytest table."""
        self.cursor.execute("select * from copytest")
        return self.cursor.fetchall()

    def check_table(self):
        """Assert that the copytest table contains exactly the test data."""
        self.assertEqual(self.table_data, self.data)

    def check_rowcount(self, number=None):
        """Assert that the rowcount of the cursor equals the given number.

        If no number is passed, the number of rows in ``self.data`` is
        expected.  This is evaluated at call time, so that subclasses
        which override ``data`` get a matching default, in line with
        check_table() and the data properties.
        """
        if number is None:
            number = len(self.data)
        self.assertEqual(self.cursor.rowcount, number)
196
197
class TestCopyFrom(TestCopy):
    """Test the copy_from method."""

    def tearDown(self):
        """Clean up and empty the table using a fresh connection."""
        super(TestCopyFrom, self).tearDown()
        # reconnect, since the test may have left the connection unusable
        self.setUp()
        self.truncate_table()
        super(TestCopyFrom, self).tearDown()

    def copy_from(self, stream, **options):
        """Copy from the given stream into the copytest table."""
        return self.cursor.copy_from(stream, 'copytest', **options)

    @property
    def data_file(self):
        """A new input stream delivering the test data as text lines."""
        return InputStream(self.data_text)

    def test_bad_params(self):
        call = self.cursor.copy_from
        # valid calls must succeed and return the cursor itself
        self.assertIs(call('0\t', 'copytest'), self.cursor)
        call('1\t', 'copytest',
             format='text', sep='\t', null='', columns=['id', 'name'])
        self.assertRaises(TypeError, call)
        self.assertRaises(TypeError, call, '0\t')
        self.assertRaises(TypeError, call, '0\t', 42)
        self.assertRaises(TypeError, call, '0\t', ['copytest'])
        self.assertRaises(TypeError, call, '0\t', 'copytest', format=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', format='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', sep=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', sep='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', null=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', size='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', columns=42)

    def test_input_string(self):
        ret = self.copy_from('42\tHello, world!')
        self.assertIs(ret, self.cursor)
        self.assertEqual(self.table_data, [(42, 'Hello, world!')])
        self.check_rowcount(1)

    def test_input_string_with_newline(self):
        self.copy_from('42\tHello, world!\n')
        self.assertEqual(self.table_data, [(42, 'Hello, world!')])
        self.check_rowcount(1)

    def test_input_string_multiple_rows(self):
        ret = self.copy_from(self.data_text)
        self.assertIs(ret, self.cursor)
        self.check_table()
        self.check_rowcount()

    if str is unicode:  # native strings are text, so test bytes as well

        def test_input_bytes(self):
            self.copy_from(b'42\tHello, world!')
            self.assertEqual(self.table_data, [(42, 'Hello, world!')])
            self.truncate_table()
            self.copy_from(self.data_text.encode('utf-8'))
            self.check_table()

    if str is not unicode:  # native strings are bytes, so test text as well

        def test_input_unicode(self):
            # NOTE(review): these literals look like double-encoded text
            # (presumably 'Würstel, Käse!'); they are kept as found, since
            # both sides of the comparison agree -- confirm and repair.
            self.copy_from(u'43\tWÃŒrstel, KÀse!')
            self.assertEqual(self.table_data, [(43, 'WÃŒrstel, KÀse!')])
            self.truncate_table()
            self.copy_from(self.data_text.decode('utf-8'))
            self.check_table()

    def test_input_iterable(self):
        self.copy_from(self.data_text.splitlines())
        self.check_table()
        self.check_rowcount()

    def test_input_iterable_with_newlines(self):
        self.copy_from('%s\n' % row for row in self.data_text.splitlines())
        self.check_table()

    def test_sep(self):
        stream = ('%d-%s' % row for row in self.data)
        self.copy_from(stream, sep='-')
        self.check_table()

    def test_null(self):
        # by default only \N is read as null
        self.copy_from('0\t\\N')
        self.assertEqual(self.table_data, [(0, None)])
        self.assertIsNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('1\tNix')
        self.assertEqual(self.table_data, [(1, 'Nix')])
        self.assertIsNotNone(self.table_data[0][1])
        self.truncate_table()
        # a custom null string is read as null
        self.copy_from('2\tNix', null='Nix')
        self.assertEqual(self.table_data, [(2, None)])
        self.assertIsNone(self.table_data[0][1])
        self.truncate_table()
        # an empty field is only null if this has been requested
        self.copy_from('3\t')
        self.assertEqual(self.table_data, [(3, '')])
        self.assertIsNotNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('4\t', null='')
        self.assertEqual(self.table_data, [(4, None)])
        self.assertIsNone(self.table_data[0][1])

    def test_columns(self):
        # columns can be passed as a string or as a list of names
        self.copy_from('1', columns='id')
        self.copy_from('2', columns=['id'])
        self.copy_from('3\tThree')
        self.copy_from('4\tFour', columns='id, name')
        self.copy_from('5\tFive', columns=['id', 'name'])
        self.assertEqual(self.table_data, [
            (1, None), (2, None), (3, 'Three'), (4, 'Four'), (5, 'Five')])
        self.check_rowcount(5)
        # the table has no column with the name "age"
        self.assertRaises(pgdb.ProgrammingError, self.copy_from,
            '6\t42', columns=['id', 'age'])
        self.check_rowcount(-1)

    def test_csv(self):
        self.copy_from(self.data_csv, format='csv')
        self.check_table()

    def test_csv_with_sep(self):
        stream = ('%d;"%s"\n' % row for row in self.data)
        self.copy_from(stream, format='csv', sep=';')
        self.check_table()
        self.check_rowcount()

    def test_binary(self):
        # the input does not start with a proper binary copy signature
        self.assertRaises(IOError, self.copy_from,
            b'NOPGCOPY\n', format='binary')
        self.check_rowcount(-1)

    def test_binary_with_sep(self):
        self.assertRaises(ValueError, self.copy_from,
            '', format='binary', sep='\t')

    def test_binary_with_unicode(self):
        self.assertRaises(ValueError, self.copy_from, u'', format='binary')

    def test_query(self):
        # copy_from needs a table, a query is not accepted
        self.assertRaises(ValueError, self.cursor.copy_from, '', "select null")

    def test_file(self):
        stream = self.data_file
        ret = self.copy_from(stream)
        self.assertIs(ret, self.cursor)
        self.check_table()
        self.assertEqual(len(stream), 0)
        # the whole input has been requested with one read of 8192 bytes
        self.assertEqual(stream.sizes, [8192])
        self.check_rowcount()

    def test_size_positive(self):
        stream = self.data_file
        size = 7
        # number of reads of the given size needed to consume the stream
        num_chunks = (len(stream) + size - 1) // size
        self.copy_from(stream, size=size)
        self.check_table()
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [size] * num_chunks)
        self.check_rowcount()

    def test_size_negative(self):
        stream = self.data_file
        self.copy_from(stream, size=-1)
        self.check_table()
        self.assertEqual(len(stream), 0)
        # the stream has been read only once, without passing a size
        self.assertEqual(stream.sizes, [None])
        self.check_rowcount()
365
366
class TestCopyTo(TestCopy):
    """Test the copy_to method."""

    @classmethod
    def setUpClass(cls):
        """Create the copytest table and fill it with the test data."""
        super(TestCopyTo, cls).setUpClass()
        con = cls.connect()
        cur = con.cursor()
        # one set of parameters per row, so this is a job for executemany()
        cur.executemany("insert into copytest values (%d, %s)", cls.data)
        cur.close()
        con.commit()
        con.close()

    def copy_to(self, stream=None, **options):
        """Copy the copytest table to the given stream (or a generator)."""
        return self.cursor.copy_to(stream, 'copytest', **options)

    @property
    def data_file(self):
        """A new and empty output stream."""
        return OutputStream()

    def test_bad_params(self):
        call = self.cursor.copy_to
        # valid calls must succeed
        call(None, 'copytest')
        call(None, 'copytest',
             format='text', sep='\t', null='', columns=['id', 'name'])
        self.assertRaises(TypeError, call)
        self.assertRaises(TypeError, call, None)
        self.assertRaises(TypeError, call, None, 42)
        self.assertRaises(TypeError, call, None, ['copytest'])
        self.assertRaises(TypeError, call, 'bad', 'copytest')
        self.assertRaises(TypeError, call, None, 'copytest', format=42)
        self.assertRaises(ValueError, call, None, 'copytest', format='bad')
        self.assertRaises(TypeError, call, None, 'copytest', sep=42)
        self.assertRaises(ValueError, call, None, 'copytest', sep='bad')
        self.assertRaises(TypeError, call, None, 'copytest', null=42)
        self.assertRaises(TypeError, call, None, 'copytest', decode='bad')
        self.assertRaises(TypeError, call, None, 'copytest', columns=42)

    def test_generator(self):
        ret = self.copy_to()
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 3)
        rows = ''.join(rows)
        self.assertIsInstance(rows, str)
        self.assertEqual(rows, self.data_text)
        self.check_rowcount()

    if str is unicode:  # native strings are text, so test bytes as well

        def test_generator_bytes(self):
            ret = self.copy_to(decode=False)
            self.assertIsInstance(ret, Iterable)
            rows = list(ret)
            self.assertEqual(len(rows), 3)
            rows = b''.join(rows)
            self.assertIsInstance(rows, bytes)
            self.assertEqual(rows, self.data_text.encode('utf-8'))

    if str is not unicode:  # native strings are bytes, so test text as well

        def test_generator_unicode(self):
            ret = self.copy_to(decode=True)
            self.assertIsInstance(ret, Iterable)
            rows = list(ret)
            self.assertEqual(len(rows), 3)
            rows = ''.join(rows)
            self.assertIsInstance(rows, unicode)
            self.assertEqual(rows, self.data_text.decode('utf-8'))

    def test_rowcount_increment(self):
        ret = self.copy_to()
        self.assertIsInstance(ret, Iterable)
        # the rowcount must be updated with every row that is delivered
        for count, _ in enumerate(ret, 1):
            self.check_rowcount(count)

    def test_decode(self):
        ret_raw = b''.join(self.copy_to(decode=False))
        ret_decoded = ''.join(self.copy_to(decode=True))
        self.assertIsInstance(ret_raw, bytes)
        self.assertIsInstance(ret_decoded, unicode)
        self.assertEqual(ret_decoded, ret_raw.decode('utf-8'))
        self.check_rowcount()

    def test_sep(self):
        ret = list(self.copy_to(sep='-'))
        self.assertEqual(ret, ['%d-%s\n' % row for row in self.data])

    def test_null(self):
        data = ['%d\t%s\n' % row for row in self.data]
        # temporarily add a row with a null value in the name column
        self.cursor.execute('insert into copytest values(4, null)')
        try:
            ret = list(self.copy_to())
            self.assertEqual(ret, data + ['4\t\\N\n'])
            ret = list(self.copy_to(null='Nix'))
            self.assertEqual(ret, data + ['4\tNix\n'])
            ret = list(self.copy_to(null=''))
            self.assertEqual(ret, data + ['4\t\n'])
        finally:
            self.cursor.execute('delete from copytest where id=4')

    def test_columns(self):
        data_id = ''.join('%d\n' % row[0] for row in self.data)
        data_name = ''.join('%s\n' % row[1] for row in self.data)
        # columns can be passed as a string or as a list of names
        ret = ''.join(self.copy_to(columns='id'))
        self.assertEqual(ret, data_id)
        ret = ''.join(self.copy_to(columns=['id']))
        self.assertEqual(ret, data_id)
        ret = ''.join(self.copy_to(columns='name'))
        self.assertEqual(ret, data_name)
        ret = ''.join(self.copy_to(columns=['name']))
        self.assertEqual(ret, data_name)
        ret = ''.join(self.copy_to(columns='id, name'))
        self.assertEqual(ret, self.data_text)
        ret = ''.join(self.copy_to(columns=['id', 'name']))
        self.assertEqual(ret, self.data_text)
        # the table has no column with the name "age"
        self.assertRaises(pgdb.ProgrammingError, self.copy_to,
            columns=['id', 'age'])

    def test_csv(self):
        ret = self.copy_to(format='csv')
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 3)
        rows = ''.join(rows)
        self.assertIsInstance(rows, str)
        self.assertEqual(rows, self.data_csv)
        self.check_rowcount(3)

    def test_csv_with_sep(self):
        rows = ''.join(self.copy_to(format='csv', sep=';'))
        self.assertEqual(rows, self.data_csv.replace(',', ';'))

    def test_binary(self):
        ret = self.copy_to(format='binary')
        self.assertIsInstance(ret, Iterable)
        # only the first chunk is fetched and checked for the signature
        for row in ret:
            self.assertTrue(row.startswith(b'PGCOPY\n\377\r\n\0'))
            break
        self.check_rowcount(1)

    def test_binary_with_sep(self):
        self.assertRaises(ValueError, self.copy_to, format='binary', sep='\t')

    def test_binary_with_unicode(self):
        self.assertRaises(ValueError, self.copy_to,
            format='binary', decode=True)

    def test_query(self):
        # instead of a table name, copy_to also takes a select query
        ret = self.cursor.copy_to(None,
            "select name||'!' from copytest where id=1941")
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 1)
        self.assertIsInstance(rows[0], str)
        self.assertEqual(rows[0], '%s!\n' % self.data[1][1])
        self.check_rowcount(1)

    def test_file(self):
        stream = self.data_file
        ret = self.copy_to(stream)
        self.assertIs(ret, self.cursor)
        self.assertEqual(str(stream), self.data_text)
        data = self.data_text
        if str is unicode:  # the stream records sizes in bytes
            data = data.encode('utf-8')
        # every row is expected as one chunk, including the newline
        sizes = [len(row) + 1 for row in data.splitlines()]
        self.assertEqual(stream.sizes, sizes)
        self.check_rowcount()
536
537
class TestBinary(TestCopy):
    """Test the copy_from and copy_to methods with binary data."""

    def test_round_trip(self):
        cursor, table = self.cursor, 'copytest'
        # start with the table filled from the textual representation
        cursor.copy_from(self.data_text, table, format='text')
        self.check_table()
        self.check_rowcount()
        # dump the contents of the table in binary format
        chunks = cursor.copy_to(None, table, format='binary')
        self.assertIsInstance(chunks, Iterable)
        dump = b''.join(chunks)
        self.assertTrue(dump.startswith(b'PGCOPY\n\377\r\n\0'))
        self.check_rowcount()
        self.truncate_table()
        # load the binary dump and expect the original data again
        cursor.copy_from(dump, table, format='binary')
        self.check_table()
        self.check_rowcount()
557
558
# Run all tests of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
Note: See TracBrowser for help on using the repository browser.