source: trunk/tests/test_dbapi20_copy.py @ 842

Last change on this file since 842 was 823, checked in by cito, 4 years ago

Raise the proper subclasses of DatabaseError

In particular, we raise IntegrityError instead of ProgrammingError for
duplicate keys. This also makes PyGreSQL more usable with SQLAlchemy.

  • Property svn:keywords set to Id
File size: 19.7 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the modern PyGreSQL interface.
5
6Sub-tests for the copy methods.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
# The abstract base classes live in collections.abc; the aliases in the
# collections module itself were removed in Python 3.10.
try:
    from collections.abc import Iterable
except ImportError:  # Python versions without collections.abc
    from collections import Iterable
19
20import pgdb  # the module under test
21
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

# Optional local overrides: the package-relative module is tried first,
# then a top-level module of the same name.  If neither can be imported,
# the defaults defined above are kept.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Make sure the name ``unicode`` exists; where it is not a builtin,
# it becomes an alias for str.
try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str
41
42
class InputStream:
    """Minimal file-like input object used as a data source in the tests.

    The payload is always kept as bytes; text passed to the constructor
    is encoded as UTF-8.  Each read() call consumes data from the front
    of the buffer and records the requested size in ``sizes``, so tests
    can check how the stream was read.
    """

    def __init__(self, data):
        """Wrap *data*, which may be text, bytes or a false value (empty)."""
        if isinstance(data, unicode):
            data = data.encode('utf-8')
        self.data = data or b''
        self.sizes = []  # size argument of every read() call, in order

    def __str__(self):
        """Return the unread data as a native string."""
        data = self.data
        if str is unicode:  # Python >= 3.0
            data = data.decode('utf-8')
        return data

    def __len__(self):
        """Return the number of bytes that have not been read yet."""
        return len(self.data)

    def read(self, size=None):
        """Consume and return up to *size* bytes from the stream.

        As with regular file objects, a size of None or a negative size
        means "everything that is left".  (Slicing with a negative size
        would silently hold back the trailing bytes instead.)  The size
        is recorded in ``sizes`` exactly as it was passed in.
        """
        if size is None or size < 0:
            output, data = self.data, b''
        else:
            output, data = self.data[:size], self.data[size:]
        self.data = data
        self.sizes.append(size)
        return output
68
69
class OutputStream:
    """Minimal file-like output object used as a data sink in the tests.

    Everything written is accumulated as bytes in ``data`` (text is
    encoded as UTF-8), and the byte length of every written chunk is
    recorded in ``sizes``.
    """

    def __init__(self):
        """Start with an empty buffer and no recorded writes."""
        self.sizes = []
        self.data = b''

    def __len__(self):
        """Return the number of bytes written so far."""
        return len(self.data)

    def __str__(self):
        """Return everything written so far as a native string."""
        if str is unicode:  # Python >= 3.0
            return self.data.decode('utf-8')
        return self.data

    def write(self, data):
        """Append *data* (text or bytes) and record its length in bytes."""
        chunk = data.encode('utf-8') if isinstance(data, unicode) else data
        self.data = self.data + chunk
        self.sizes.append(len(chunk))
90
91
class TestStreams(unittest.TestCase):
    """Check the InputStream and OutputStream helper classes themselves."""

    def test_input(self):
        """Reading consumes the data and records every requested size."""
        source = InputStream('Hello, Wörld!')
        self.assertIsInstance(source.data, bytes)
        self.assertEqual(source.data, b'Hello, W\xc3\xb6rld!')
        self.assertIsInstance(str(source), str)
        self.assertEqual(str(source), 'Hello, Wörld!')
        self.assertEqual(len(source), 14)
        # sized reads return exactly the requested number of bytes
        sized_reads = ((3, b'Hel'), (2, b'lo'), (1, b','), (1, b' '))
        for size, expected in sized_reads:
            self.assertEqual(source.read(size), expected)
        # an unsized read drains the rest, a second one finds nothing
        for expected in (b'W\xc3\xb6rld!', b''):
            self.assertEqual(source.read(), expected)
        self.assertEqual(len(source), 0)
        self.assertEqual(source.sizes, [3, 2, 1, 1, None, None])

    def test_output(self):
        """Writing accumulates the data and records every chunk length."""
        sink = OutputStream()
        self.assertEqual(len(sink), 0)
        chunks = ('Hel', 'lo', ',', ' ', 'Wörld!')
        for chunk in chunks:
            sink.write(chunk)
        self.assertIsInstance(sink.data, bytes)
        self.assertEqual(sink.data, b'Hello, W\xc3\xb6rld!')
        self.assertIsInstance(str(sink), str)
        self.assertEqual(str(sink), 'Hello, Wörld!')
        self.assertEqual(len(sink), 14)
        # the last chunk is 7 bytes long because 'ö' takes two bytes
        self.assertEqual(sink.sizes, [3, 2, 1, 1, 7])
121
122
class TestCopy(unittest.TestCase):
    """Common fixture for the copy tests.

    The ``copytest`` table is (re)created once per test class and
    dropped again afterwards.  Every single test gets its own connection
    and cursor, which are rolled back and closed when the test is done.
    """

    # Set by setUpClass() so that setUp() can verify the table exists.
    cls_set_up = False

    # The rows used as test data.  Rows 1 and 2 are replaced with plain
    # ASCII variants by setUpClass() if the UTF-8 probe query fails.
    data = [(1935, 'Luciano Pavarotti'),
            (1941, 'Plácido Domingo'),
            (1946, 'José Carreras')]

    # Set to False by setUpClass() if the UTF-8 probe query fails.
    can_encode = True

    @staticmethod
    def connect():
        """Open a new connection to the configured test database."""
        return pgdb.connect(database=dbname,
            host='%s:%d' % (dbhost or '', dbport or -1))

    @classmethod
    def setUpClass(cls):
        """Create an empty copytest table and probe for UTF-8 support."""
        con = cls.connect()
        try:
            cur = con.cursor()
            cur.execute("set client_min_messages=warning")
            cur.execute("drop table if exists copytest cascade")
            cur.execute("create table copytest ("
                "id smallint primary key, name varchar(64))")
            cur.close()
            con.commit()
            cur = con.cursor()
            try:
                cur.execute("set client_encoding=utf8")
                cur.execute("select 'Plácido and José'").fetchone()
            except (pgdb.DataError, pgdb.NotSupportedError):
                # fall back to test data without accented characters
                cls.data[1] = (1941, 'Plaacido Domingo')
                cls.data[2] = (1946, 'Josee Carreras')
                cls.can_encode = False
            cur.close()
        finally:
            # do not leak the connection if anything above fails
            con.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the copytest table again."""
        con = cls.connect()
        try:
            cur = con.cursor()
            cur.execute("set client_min_messages=warning")
            cur.execute("drop table if exists copytest cascade")
            con.commit()
        finally:
            con.close()

    def setUp(self):
        """Open a fresh connection and cursor for the test."""
        self.assertTrue(self.cls_set_up)
        self.con = self.connect()
        self.cursor = self.con.cursor()
        self.cursor.execute("set client_encoding=utf8")

    def tearDown(self):
        """Close cursor and connection, discarding uncommitted changes.

        This is a best-effort cleanup: each step is attempted even if an
        earlier one fails, and errors are deliberately ignored.
        """
        try:
            self.cursor.close()
        except Exception:
            pass
        try:
            self.con.rollback()
        except Exception:
            pass
        try:
            self.con.close()
        except Exception:
            pass

    @property
    def data_text(self):
        """The test data as tab separated lines, one line per row."""
        return ''.join('%d\t%s\n' % row for row in self.data)

    @property
    def data_csv(self):
        """The test data as comma separated lines, one line per row."""
        return ''.join('%d,%s\n' % row for row in self.data)

    def truncate_table(self):
        """Remove all rows from the copytest table."""
        self.cursor.execute("truncate table copytest")

    @property
    def table_data(self):
        """All rows currently in the copytest table, ordered by id.

        The explicit ordering matters because the result is compared
        against ordered lists, and the database does not guarantee any
        particular row order for a query without an ORDER BY clause.
        """
        self.cursor.execute("select * from copytest order by id")
        return self.cursor.fetchall()

    def check_table(self):
        """Assert that the table contains exactly the test data."""
        self.assertEqual(self.table_data, self.data)

    def check_rowcount(self, number=len(data)):
        """Assert the cursor's rowcount (default: number of data rows)."""
        self.assertEqual(self.cursor.rowcount, number)
210
211
class TestCopyFrom(TestCopy):
    """Test the copy_from method."""

    def tearDown(self):
        """Clean up after the test and empty the table on a new connection."""
        super(TestCopyFrom, self).tearDown()
        # reconnect, since the connection of the test has just been closed
        self.setUp()
        self.truncate_table()
        super(TestCopyFrom, self).tearDown()

    def copy_from(self, stream, **options):
        """Copy from the given stream into the copytest table."""
        return self.cursor.copy_from(stream, 'copytest', **options)

    @property
    def data_file(self):
        """A new input stream holding the test data in text format."""
        return InputStream(self.data_text)

    def test_bad_params(self):
        """Check the type and value validation of the parameters."""
        call = self.cursor.copy_from
        # valid calls must work and hand back the cursor
        self.assertIs(call('0\t', 'copytest'), self.cursor)
        call('1\t', 'copytest',
             format='text', sep='\t', null='', columns=['id', 'name'])
        self.assertRaises(TypeError, call)
        self.assertRaises(TypeError, call, None)
        self.assertRaises(TypeError, call, None, None)
        self.assertRaises(TypeError, call, '0\t')
        self.assertRaises(TypeError, call, '0\t', None)
        self.assertRaises(TypeError, call, '0\t', 42)
        self.assertRaises(TypeError, call, '0\t', ['copytest'])
        self.assertRaises(TypeError, call, '0\t', 'copytest', format=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', format='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', sep=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', sep='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', null=42)
        self.assertRaises(ValueError, call, '0\t', 'copytest', size='bad')
        self.assertRaises(TypeError, call, '0\t', 'copytest', columns=42)
        self.assertRaises(ValueError, call, b'', 'copytest',
            format='binary', sep=',')

    def test_input_string(self):
        """Copy a single row given as a string without line ending."""
        ret = self.copy_from('42\tHello, world!')
        self.assertIs(ret, self.cursor)
        self.assertEqual(self.table_data, [(42, 'Hello, world!')])
        self.check_rowcount(1)

    def test_input_string_with_newline(self):
        """Copy a single row given as a string with line ending."""
        self.copy_from('42\tHello, world!\n')
        self.assertEqual(self.table_data, [(42, 'Hello, world!')])
        self.check_rowcount(1)

    def test_input_string_multiple_rows(self):
        """Copy several rows given as one string."""
        ret = self.copy_from(self.data_text)
        self.assertIs(ret, self.cursor)
        self.check_table()
        self.check_rowcount()

    if str is unicode:  # Python >= 3.0

        def test_input_bytes(self):
            """Copy rows given as bytes instead of native strings."""
            self.copy_from(b'42\tHello, world!')
            self.assertEqual(self.table_data, [(42, 'Hello, world!')])
            self.truncate_table()
            self.copy_from(self.data_text.encode('utf-8'))
            self.check_table()

    else:  # Python < 3.0

        def test_input_unicode(self):
            """Copy rows given as unicode instead of native strings."""
            if not self.can_encode:
                self.skipTest('database does not support utf8')
            self.copy_from(u'43\tWürstel, Käse!')
            self.assertEqual(self.table_data, [(43, 'Würstel, Käse!')])
            self.truncate_table()
            self.copy_from(self.data_text.decode('utf-8'))
            self.check_table()

    def test_input_iterable(self):
        """Copy rows from a list of lines without line endings."""
        self.copy_from(self.data_text.splitlines())
        self.check_table()
        self.check_rowcount()

    def test_input_iterable_invalid(self):
        """An iterable yielding something that is not a string fails."""
        self.assertRaises(IOError, self.copy_from, [None])

    def test_input_iterable_with_newlines(self):
        """Copy rows from a generator of lines with line endings."""
        self.copy_from('%s\n' % row for row in self.data_text.splitlines())
        self.check_table()

    if str is unicode:  # Python >= 3.0

        def test_input_iterable_bytes(self):
            """Copy rows from a generator of lines given as bytes."""
            self.copy_from(row.encode('utf-8')
                for row in self.data_text.splitlines())
            self.check_table()

    def test_sep(self):
        """Copy rows using a custom column separator."""
        stream = ('%d-%s' % row for row in self.data)
        self.copy_from(stream, sep='-')
        self.check_table()

    def test_null(self):
        """Check which input is read as NULL, by default and with null=."""
        self.copy_from('0\t\\N')
        self.assertEqual(self.table_data, [(0, None)])
        self.assertIsNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('1\tNix')
        self.assertEqual(self.table_data, [(1, 'Nix')])
        self.assertIsNotNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('2\tNix', null='Nix')
        self.assertEqual(self.table_data, [(2, None)])
        self.assertIsNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('3\t')
        self.assertEqual(self.table_data, [(3, '')])
        self.assertIsNotNone(self.table_data[0][1])
        self.truncate_table()
        self.copy_from('4\t', null='')
        self.assertEqual(self.table_data, [(4, None)])
        self.assertIsNone(self.table_data[0][1])

    def test_columns(self):
        """Copy into selected columns given as a string or as a list."""
        self.copy_from('1', columns='id')
        self.copy_from('2', columns=['id'])
        self.copy_from('3\tThree')
        self.copy_from('4\tFour', columns='id, name')
        self.copy_from('5\tFive', columns=['id', 'name'])
        self.assertEqual(self.table_data, [
            (1, None), (2, None), (3, 'Three'), (4, 'Four'), (5, 'Five')])
        self.check_rowcount(5)
        # a column that does not exist in the table must be rejected
        self.assertRaises(pgdb.ProgrammingError, self.copy_from,
            '6\t42', columns=['id', 'age'])
        self.check_rowcount(-1)

    def test_csv(self):
        """Copy rows given in CSV format."""
        self.copy_from(self.data_csv, format='csv')
        self.check_table()

    def test_csv_with_sep(self):
        """Copy rows given in CSV format with a custom separator."""
        stream = ('%d;"%s"\n' % row for row in self.data)
        self.copy_from(stream, format='csv', sep=';')
        self.check_table()
        self.check_rowcount()

    def test_binary(self):
        """Input that is not valid binary copy data is rejected."""
        self.assertRaises(IOError, self.copy_from,
            b'NOPGCOPY\n', format='binary')
        self.check_rowcount(-1)

    def test_binary_with_sep(self):
        """The binary format cannot be combined with a separator."""
        self.assertRaises(ValueError, self.copy_from,
            '', format='binary', sep='\t')

    def test_binary_with_unicode(self):
        """The binary format cannot be combined with unicode input."""
        self.assertRaises(ValueError, self.copy_from, u'', format='binary')

    def test_query(self):
        """A query is not accepted in place of the table name."""
        self.assertRaises(ValueError, self.cursor.copy_from, '', "select null")

    def test_file(self):
        """Copy from a file-like object using the default chunk size."""
        stream = self.data_file
        ret = self.copy_from(stream)
        self.assertIs(ret, self.cursor)
        self.check_table()
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [8192])
        self.check_rowcount()

    def test_size_positive(self):
        """Copy from a file-like object in chunks of the given size."""
        stream = self.data_file
        size = 7
        # number of reads needed to consume the stream (rounded up)
        num_chunks = (len(stream) + size - 1) // size
        self.copy_from(stream, size=size)
        self.check_table()
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [size] * num_chunks)
        self.check_rowcount()

    def test_size_negative(self):
        """With a negative size the stream is read in one unsized call."""
        stream = self.data_file
        self.copy_from(stream, size=-1)
        self.check_table()
        self.assertEqual(len(stream), 0)
        self.assertEqual(stream.sizes, [None])
        self.check_rowcount()

    def test_size_invalid(self):
        """A size that is not a number is rejected for file-like input."""
        self.assertRaises(TypeError,
            self.copy_from, self.data_file, size='invalid')
400
401
class TestCopyTo(TestCopy):
    """Test the copy_to method."""

    @classmethod
    def setUpClass(cls):
        """Create the copytest table and fill it with the test data."""
        super(TestCopyTo, cls).setUpClass()
        con = cls.connect()
        try:
            cur = con.cursor()
            cur.execute("set client_encoding=utf8")
            # one parameter tuple per row, so this is a job for executemany()
            cur.executemany(
                "insert into copytest values (%d, %s)", cls.data)
            cur.close()
            con.commit()
        finally:
            # do not leak the connection if anything above fails
            con.close()

    def copy_to(self, stream=None, **options):
        """Copy from the copytest table to the given stream."""
        return self.cursor.copy_to(stream, 'copytest', **options)

    @property
    def data_file(self):
        """A new, empty output stream."""
        return OutputStream()

    def test_bad_params(self):
        """Check the type and value validation of the parameters."""
        call = self.cursor.copy_to
        # these calls are valid and must not raise
        call(None, 'copytest')
        call(None, 'copytest',
             format='text', sep='\t', null='', columns=['id', 'name'])
        self.assertRaises(TypeError, call)
        self.assertRaises(TypeError, call, None)
        self.assertRaises(TypeError, call, None, 42)
        self.assertRaises(TypeError, call, None, ['copytest'])
        self.assertRaises(TypeError, call, 'bad', 'copytest')
        self.assertRaises(TypeError, call, None, 'copytest', format=42)
        self.assertRaises(ValueError, call, None, 'copytest', format='bad')
        self.assertRaises(TypeError, call, None, 'copytest', sep=42)
        self.assertRaises(ValueError, call, None, 'copytest', sep='bad')
        self.assertRaises(TypeError, call, None, 'copytest', null=42)
        self.assertRaises(TypeError, call, None, 'copytest', decode='bad')
        self.assertRaises(TypeError, call, None, 'copytest', columns=42)

    def test_generator(self):
        """Without a stream, the rows are returned as an iterable."""
        ret = self.copy_to()
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 3)
        rows = ''.join(rows)
        self.assertIsInstance(rows, str)
        self.assertEqual(rows, self.data_text)
        self.check_rowcount()

    if str is unicode:  # Python >= 3.0

        def test_generator_bytes(self):
            """With decode=False, the rows are returned as bytes."""
            ret = self.copy_to(decode=False)
            self.assertIsInstance(ret, Iterable)
            rows = list(ret)
            self.assertEqual(len(rows), 3)
            rows = b''.join(rows)
            self.assertIsInstance(rows, bytes)
            self.assertEqual(rows, self.data_text.encode('utf-8'))

    else:  # Python < 3.0

        def test_generator_unicode(self):
            """With decode=True, the rows are returned as unicode."""
            ret = self.copy_to(decode=True)
            self.assertIsInstance(ret, Iterable)
            rows = list(ret)
            self.assertEqual(len(rows), 3)
            rows = ''.join(rows)
            self.assertIsInstance(rows, unicode)
            self.assertEqual(rows, self.data_text.decode('utf-8'))

    def test_rowcount_increment(self):
        """The rowcount grows by one with every row that is fetched."""
        ret = self.copy_to()
        self.assertIsInstance(ret, Iterable)
        for fetched, _ in enumerate(ret, 1):
            self.check_rowcount(fetched)

    def test_decode(self):
        """The decode option switches between bytes and unicode output."""
        ret_raw = b''.join(self.copy_to(decode=False))
        ret_decoded = ''.join(self.copy_to(decode=True))
        self.assertIsInstance(ret_raw, bytes)
        self.assertIsInstance(ret_decoded, unicode)
        self.assertEqual(ret_decoded, ret_raw.decode('utf-8'))
        self.check_rowcount()

    def test_sep(self):
        """Copy rows using a custom column separator."""
        ret = list(self.copy_to(sep='-'))
        self.assertEqual(ret, ['%d-%s\n' % row for row in self.data])

    def test_null(self):
        """Check how NULL is written, by default and with null=."""
        data = ['%d\t%s\n' % row for row in self.data]
        self.cursor.execute('insert into copytest values(4, null)')
        try:
            ret = list(self.copy_to())
            self.assertEqual(ret, data + ['4\t\\N\n'])
            ret = list(self.copy_to(null='Nix'))
            self.assertEqual(ret, data + ['4\tNix\n'])
            ret = list(self.copy_to(null=''))
            self.assertEqual(ret, data + ['4\t\n'])
        finally:
            # remove the extra row again, whatever the outcome
            self.cursor.execute('delete from copytest where id=4')

    def test_columns(self):
        """Copy selected columns given as a string or as a list."""
        data_id = ''.join('%d\n' % row[0] for row in self.data)
        data_name = ''.join('%s\n' % row[1] for row in self.data)
        ret = ''.join(self.copy_to(columns='id'))
        self.assertEqual(ret, data_id)
        ret = ''.join(self.copy_to(columns=['id']))
        self.assertEqual(ret, data_id)
        ret = ''.join(self.copy_to(columns='name'))
        self.assertEqual(ret, data_name)
        ret = ''.join(self.copy_to(columns=['name']))
        self.assertEqual(ret, data_name)
        ret = ''.join(self.copy_to(columns='id, name'))
        self.assertEqual(ret, self.data_text)
        ret = ''.join(self.copy_to(columns=['id', 'name']))
        self.assertEqual(ret, self.data_text)
        # a column that does not exist in the table must be rejected
        self.assertRaises(pgdb.ProgrammingError, self.copy_to,
            columns=['id', 'age'])

    def test_csv(self):
        """Copy rows in CSV format."""
        ret = self.copy_to(format='csv')
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 3)
        rows = ''.join(rows)
        self.assertIsInstance(rows, str)
        self.assertEqual(rows, self.data_csv)
        self.check_rowcount(3)

    def test_csv_with_sep(self):
        """Copy rows in CSV format with a custom separator."""
        rows = ''.join(self.copy_to(format='csv', sep=';'))
        self.assertEqual(rows, self.data_csv.replace(',', ';'))

    def test_binary(self):
        """The binary output starts with the expected signature bytes."""
        ret = self.copy_to(format='binary')
        self.assertIsInstance(ret, Iterable)
        # only the first chunk is fetched and inspected
        for row in ret:
            self.assertTrue(row.startswith(b'PGCOPY\n\377\r\n\0'))
            break
        self.check_rowcount(1)

    def test_binary_with_sep(self):
        """The binary format cannot be combined with a separator."""
        self.assertRaises(ValueError, self.copy_to, format='binary', sep='\t')

    def test_binary_with_unicode(self):
        """The binary format cannot be combined with decoding."""
        self.assertRaises(ValueError, self.copy_to,
            format='binary', decode=True)

    def test_query(self):
        """Copy the result of a query instead of a whole table."""
        # column names cannot be combined with a query
        self.assertRaises(ValueError, self.cursor.copy_to, None,
            "select name from copytest", columns='noname')
        ret = self.cursor.copy_to(None,
            "select name||'!' from copytest where id=1941")
        self.assertIsInstance(ret, Iterable)
        rows = list(ret)
        self.assertEqual(len(rows), 1)
        self.assertIsInstance(rows[0], str)
        self.assertEqual(rows[0], '%s!\n' % self.data[1][1])
        self.check_rowcount(1)

    def test_file(self):
        """Copy to a file-like object, which receives one write per row."""
        stream = self.data_file
        ret = self.copy_to(stream)
        self.assertIs(ret, self.cursor)
        self.assertEqual(str(stream), self.data_text)
        data = self.data_text
        if str is unicode:  # Python >= 3.0
            data = data.encode('utf-8')
        # expected chunk sizes: byte length of each row plus its newline
        sizes = [len(row) + 1 for row in data.splitlines()]
        self.assertEqual(stream.sizes, sizes)
        self.check_rowcount()
574
575
class TestBinary(TestCopy):
    """Test the copy_from and copy_to methods with binary data."""

    def test_round_trip(self):
        """Load text data, dump it as binary and load the dump again."""
        cursor = self.cursor
        signature = b'PGCOPY\n\377\r\n\0'
        # step 1: populate the table from the textual test data
        cursor.copy_from(self.data_text, 'copytest', format='text')
        self.check_table()
        self.check_rowcount()
        # step 2: read the table contents back in binary format
        chunks = cursor.copy_to(None, 'copytest', format='binary')
        self.assertIsInstance(chunks, Iterable)
        dump = b''.join(chunks)
        self.assertTrue(dump.startswith(signature))
        self.check_rowcount()
        # step 3: empty the table and restore it from the binary dump
        self.truncate_table()
        cursor.copy_from(dump, 'copytest', format='binary')
        self.check_table()
        self.check_rowcount()
595
596
# Run all tests of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Note: See TracBrowser for help on using the repository browser.