source: trunk/tests/test_classic.py @ 730

Last change on this file since 730 was 730, checked in by cito, 4 years ago

Use query parameters instead of inline values

The single row methods of the DB wrapper class created queries with inline values
instead of passing them separately as parameters, even though our query method
does have this capability. Using query parameters also spares us a lot of quoting
and escaping that is necessary when passing values inline.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 11.6 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4from __future__ import print_function
5
6try:
7    import unittest2 as unittest  # for Python < 2.7
8except ImportError:
9    import unittest
10
11import sys
12from functools import partial
13from time import sleep
14from threading import Thread
15
16from pg import *
17
# Connection parameters of the database the tests run against.  These
# defaults may be overridden by an optional LOCAL_PyGreSQL module, which is
# looked up relative to this package first and then as a top-level module.
dbname, dbhost, dbport = 'unittest', None, 5432

try:
    # ValueError is what Python 2 raises for a relative import that is
    # attempted outside of a package.
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local overrides available, keep the defaults
31
32
def opendb():
    """Return a new DB connection to the test database.

    Session options are set explicitly so that the tests do not depend
    on the defaults configured on the server.
    """
    connection = DB(dbname, dbhost, dbport)
    session_options = (
        "SET DATESTYLE TO 'ISO'",
        "SET TIME ZONE 'EST5EDT'",
        "SET DEFAULT_WITH_OIDS=FALSE",
        "SET STANDARD_CONFORMING_STRINGS=FALSE",
    )
    for statement in session_options:
        connection.query(statement)
    return connection
40
# Remove whatever a previous (possibly aborted) run may have left behind:
# first the tables, then the schemas containing them.  A statement fails
# when its object does not exist; such failures are deliberately ignored.
db = opendb()
for q in ['DROP TABLE _test1._test_schema', 'DROP TABLE _test2._test_schema',
          'DROP SCHEMA _test1', 'DROP SCHEMA _test2']:
    try:
        db.query(q)
    except Exception:  # nothing to drop - that's fine
        pass
db.close()
53
54
class UtilityTest(unittest.TestCase):
    """Tests for the classic DB wrapper class of PyGreSQL."""

    def _open(self):
        """Open a test connection that is closed automatically after the test."""
        db = opendb()
        self.addCleanup(db.close)
        return db

    @staticmethod
    def _wait_for(condition, tries=500, delay=0.01):
        """Poll condition() until it is true or all tries are used up.

        Returns True as soon as condition() is true, False otherwise.
        With the defaults this waits at most about five seconds.
        """
        for _ in range(tries):
            if condition():
                return True
            sleep(delay)
        return False

    def setUp(self):
        """Setup test tables or empty them if they already exist."""
        db = opendb()
        try:
            for t in ('_test1', '_test2'):
                try:
                    db.query("CREATE SCHEMA " + t)
                except Error:
                    pass
                try:
                    db.query("CREATE TABLE %s._test_schema "
                        "(%s int PRIMARY KEY)" % (t, t))
                except Error:
                    db.query("DELETE FROM %s._test_schema" % t)
            try:
                db.query("CREATE TABLE _test_schema "
                    "(_test int PRIMARY KEY, _i interval, dvar int DEFAULT 999)")
            except Error:
                db.query("DELETE FROM _test_schema")
            try:
                db.query("CREATE VIEW _test_vschema AS "
                    "SELECT _test, 'abc'::text AS _test2 FROM _test_schema")
            except Error:
                pass
        finally:
            # Each test opens its own connection; don't leak this one.
            db.close()

    def test_invalidname(self):
        """Make sure that invalid table names are caught"""
        db = self._open()
        self.assertRaises(ProgrammingError, db.get_attnames, 'x.y.z')

    def test_schema(self):
        """Does it differentiate the same table name in different schemas"""
        db = self._open()
        # see if they differentiate the table names properly
        self.assertEqual(
            db.get_attnames('_test_schema'),
            {'_test': 'int', '_i': 'date', 'dvar': 'int'}
        )
        self.assertEqual(
            db.get_attnames('public._test_schema'),
            {'_test': 'int', '_i': 'date', 'dvar': 'int'}
        )
        self.assertEqual(
            db.get_attnames('_test1._test_schema'),
            {'_test1': 'int'}
        )
        self.assertEqual(
            db.get_attnames('_test2._test_schema'),
            {'_test2': 'int'}
        )

    def test_pkey(self):
        db = self._open()
        self.assertEqual(db.pkey('_test_schema'), '_test')
        self.assertEqual(db.pkey('public._test_schema'), '_test')
        self.assertEqual(db.pkey('_test1._test_schema'), '_test1')
        self.assertEqual(db.pkey('_test2._test_schema'), '_test2')
        self.assertRaises(KeyError, db.pkey, '_test_vschema')

    def test_get(self):
        db = self._open()
        db.query("INSERT INTO _test_schema VALUES (1234)")
        db.get('_test_schema', 1234)
        db.get('_test_schema', 1234, keyname='_test')
        # A view has no primary key, so the key name must be given.
        self.assertRaises(ProgrammingError, db.get, '_test_vschema', 1234)
        db.get('_test_vschema', 1234, keyname='_test')

    def test_params(self):
        db = self._open()
        db.query("INSERT INTO _test_schema VALUES ($1, $2, $3)", 12, None, 34)
        d = db.get('_test_schema', 12)
        self.assertEqual(d['dvar'], 34)

    def test_insert(self):
        db = self._open()
        d = dict(_test=1234)
        db.insert('_test_schema', d)
        # The passed dict is updated with the column default.
        self.assertEqual(d['dvar'], 999)
        # Check the row returned when the values come as keyword arguments,
        # not the dict from the previous insert.
        d = db.insert('_test_schema', _test=1235)
        self.assertEqual(d['dvar'], 999)

    def test_context_manager(self):
        db = self._open()
        t = '_test_schema'
        d = dict(_test=1235)
        with db:
            db.insert(t, d)
            d['_test'] += 1
            db.insert(t, d)
        # The duplicate insert must fail and roll back the whole block.
        with self.assertRaises(ProgrammingError):
            with db:
                d['_test'] += 1
                db.insert(t, d)
                db.insert(t, d)
        with db:
            d['_test'] += 1
            db.insert(t, d)
            d['_test'] += 1
            db.insert(t, d)
        self.assertTrue(db.get(t, 1235))
        self.assertTrue(db.get(t, 1236))
        self.assertRaises(DatabaseError, db.get, t, 1237)
        self.assertTrue(db.get(t, 1238))
        self.assertTrue(db.get(t, 1239))

    def test_sqlstate(self):
        db = self._open()
        db.query("INSERT INTO _test_schema VALUES (1234)")
        try:
            db.query("INSERT INTO _test_schema VALUES (1234)")
        except DatabaseError as error:
            # currently PyGreSQL does not support IntegrityError
            self.assertTrue(isinstance(error, ProgrammingError))
            # the SQLSTATE error code for unique violation is 23505
            self.assertEqual(error.sqlstate, '23505')
        else:
            self.fail("Inserting a duplicate key did not raise an error")

    def test_mixed_case(self):
        db = self._open()
        try:
            db.query('CREATE TABLE _test_mc ("_Test" int PRIMARY KEY)')
        except Error:
            db.query("DELETE FROM _test_mc")
        d = dict(_Test=1234)
        db.insert('_test_mc', d)

    def test_update(self):
        db = self._open()
        db.query("INSERT INTO _test_schema VALUES (1234)")

        # update from a row dict
        r = db.get('_test_schema', 1234)
        r['dvar'] = 123
        db.update('_test_schema', r)
        r = db.get('_test_schema', 1234)
        self.assertEqual(r['dvar'], 123)

        # update from keyword arguments only
        r = db.get('_test_schema', 1234)
        self.assertIn('dvar', r)
        db.update('_test_schema', _test=1234, dvar=456)
        r = db.get('_test_schema', 1234)
        self.assertEqual(r['dvar'], 456)

        # keyword arguments override the row dict; use a new value so that
        # the override is actually observable
        r = db.get('_test_schema', 1234)
        db.update('_test_schema', r, dvar=789)
        r = db.get('_test_schema', 1234)
        self.assertEqual(r['dvar'], 789)

    def notify_callback(self, arg_dict):
        # An empty/None arg_dict signals that the handler timed out.
        if arg_dict:
            arg_dict['called'] = True
        else:
            self.notify_timeout = True

    def test_notify(self, options=None):
        """Check that a notification handler catches an event and its stop
        event, sent either via SQL or via the handler's notify() method."""
        if not options:
            options = {}
        run_as_method = options.get('run_as_method')
        call_notify = options.get('call_notify')
        two_payloads = options.get('two_payloads')
        # Not registered for automatic cleanup: the handler built on this
        # connection is closed explicitly below, which presumably closes
        # the connection as well (NOTE(review): confirm).
        db = opendb()
        # Get function under test, can be standalone or DB method.
        fut = db.notification_handler if run_as_method else partial(
            NotificationHandler, db)
        arg_dict = dict(event=None, called=False)
        self.notify_timeout = False
        # Listen for 'event_1'.
        target = fut('event_1', self.notify_callback, arg_dict, 5)
        thread = Thread(None, target)
        thread.start()
        db2 = None
        try:
            # Wait until the thread has started.
            self._wait_for(lambda: target.listening)
            self.assertTrue(target.listening)
            self.assertTrue(thread.is_alive())
            # Open another connection for sending notifications.
            db2 = opendb()
            # Generate notification from the other connection.
            if two_payloads:
                db2.begin()
            if call_notify:
                if two_payloads:
                    target.notify(db2, payload='payload 0')
                target.notify(db2, payload='payload 1')
            else:
                if two_payloads:
                    db2.query("notify event_1, 'payload 0'")
                db2.query("notify event_1, 'payload 1'")
            if two_payloads:
                db2.commit()
            # Wait until the notification has been caught.
            self._wait_for(
                lambda: arg_dict['called'] or self.notify_timeout)
            # Check that callback has been invoked.
            self.assertTrue(arg_dict['called'])
            self.assertEqual(arg_dict['event'], 'event_1')
            self.assertEqual(arg_dict['extra'], 'payload 1')
            self.assertTrue(isinstance(arg_dict['pid'], int))
            self.assertFalse(self.notify_timeout)
            arg_dict['called'] = False
            self.assertTrue(thread.is_alive())
            # Generate stop notification.
            if call_notify:
                target.notify(db2, stop=True, payload='payload 2')
            else:
                db2.query("notify stop_event_1, 'payload 2'")
            db2.close()
            db2 = None
            # Wait until the notification has been caught.
            self._wait_for(
                lambda: arg_dict['called'] or self.notify_timeout)
            # Check that callback has been invoked.
            self.assertTrue(arg_dict['called'])
            self.assertEqual(arg_dict['event'], 'stop_event_1')
            self.assertEqual(arg_dict['extra'], 'payload 2')
            self.assertTrue(isinstance(arg_dict['pid'], int))
            self.assertFalse(self.notify_timeout)
            thread.join(5)
            self.assertFalse(thread.is_alive())
            self.assertFalse(target.listening)
        finally:
            # Clean up in every case, but let assertion failures propagate
            # so that a broken notification handler fails the test.
            if db2 is not None:
                db2.close()
            target.close()
            if thread.is_alive():
                thread.join(5)

    def test_notify_other_options(self):
        for run_as_method in False, True:
            for call_notify in False, True:
                for two_payloads in False, True:
                    options = dict(
                        run_as_method=run_as_method,
                        call_notify=call_notify,
                        two_payloads=two_payloads)
                    # The all-False combination is covered by test_notify.
                    if any(options.values()):
                        self.test_notify(options)

    def test_notify_timeout(self):
        for run_as_method in False, True:
            db = opendb()
            # Get function under test, can be standalone or DB method.
            fut = db.notification_handler if run_as_method else partial(
                NotificationHandler, db)
            arg_dict = dict(event=None, called=False)
            self.notify_timeout = False
            # Listen for 'event_1' with timeout of 10ms.
            target = fut('event_1', self.notify_callback, arg_dict, 0.01)
            thread = Thread(None, target)
            thread.start()
            try:
                # Wait for the handler to give up instead of sleeping for a
                # fixed time, which is flaky on a busy machine.
                thread.join(5)
                # Verify that we've indeed timed out.
                self.assertFalse(arg_dict.get('called'))
                self.assertTrue(self.notify_timeout)
                self.assertFalse(thread.is_alive())
                self.assertFalse(target.listening)
            finally:
                target.close()
321
if __name__ == '__main__':
    # Command line: [-l] lists the test names; otherwise the given test
    # names are run (all if none), with -v for verbose output and -f for
    # stopping at the first failure.
    args = sys.argv[1:]
    if len(args) == 1 and args[0] == '-l':
        print('\n'.join(unittest.getTestCaseNames(UtilityTest, 'test_')))
        sys.exit(0)

    test_list = [name for name in args if not name.startswith('-')]
    if not test_list:
        test_list = unittest.getTestCaseNames(UtilityTest, 'test_')

    suite = unittest.TestSuite()
    for test_name in test_list:
        try:
            suite.addTest(UtilityTest(test_name))
        except Exception as error:
            # sys.exc_value does not exist in Python 3, so report the
            # caught exception itself (e.g. an unknown test name).
            print("\n ERROR: %s.\n" % error)
            sys.exit(1)

    verbosity = 2 if '-v' in args else 1
    # '-l' combined with other arguments has always enabled fail-fast;
    # it is kept for backward compatibility next to the clearer '-f'.
    failfast = '-f' in args or '-l' in args
    runner = unittest.TextTestRunner(verbosity=verbosity, failfast=failfast)
    rc = runner.run(suite)
    sys.exit(1 if rc.errors or rc.failures else 0)
Note: See TracBrowser for help on using the repository browser.