source: trunk/tests/test_classic.py @ 765

Last change on this file since 765 was 762, checked in by cito, 4 years ago

Docs and 100% test coverage for NotificationHandler

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 11.6 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4from __future__ import print_function
5
6try:
7    import unittest2 as unittest  # for Python < 2.7
8except ImportError:
9    import unittest
10
11import sys
12from functools import partial
13from time import sleep
14from threading import Thread
15
16from pg import *
17
18# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
19# get our information from that.  Otherwise we use the defaults.
20dbname = 'unittest'
21dbhost = None
22dbport = 5432
23
24try:
25    from .LOCAL_PyGreSQL import *
26except (ImportError, ValueError):
27    try:
28        from LOCAL_PyGreSQL import *
29    except ImportError:
30        pass
31
32
def opendb():
    """Return a new DB connection to the test database.

    The session's date style, time zone, OID default and string escaping
    mode are set explicitly, so the tests do not depend on the server's
    defaults for these settings.
    """
    connection = DB(dbname, dbhost, dbport)
    session_settings = (
        "SET DATESTYLE TO 'ISO'",
        "SET TIME ZONE 'EST5EDT'",
        "SET DEFAULT_WITH_OIDS=FALSE",
        "SET STANDARD_CONFORMING_STRINGS=FALSE",
    )
    for statement in session_settings:
        connection.query(statement)
    return connection
40
# Remove the schema-qualified test tables and their schemas left over from
# an earlier run.  Each statement is tried on its own; failures (for example
# because the object does not exist) are deliberately ignored.
db = opendb()
for q in ("DROP TABLE _test1._test_schema",
          "DROP TABLE _test2._test_schema",
          "DROP SCHEMA _test1",
          "DROP SCHEMA _test2"):
    try:
        db.query(q)
    except Exception:
        continue
db.close()
53
54
class UtilityTest(unittest.TestCase):
    """Tests for the classic PyGreSQL DB wrapper against a live database.

    setUp() makes sure the tables _test1._test_schema, _test2._test_schema
    and _test_schema, plus the view _test_vschema, exist and are empty.
    Every test then opens its own connection.
    """

    def _opendb(self):
        """Open a test connection that is closed when the test finishes.

        Not used by the notification tests. There the connection is passed
        to a NotificationHandler, which the test shuts down with
        target.close().
        """
        db = opendb()
        self.addCleanup(db.close)
        return db

    @staticmethod
    def _wait_until(condition, tries=500, delay=0.01):
        """Poll condition() up to tries times, sleeping delay seconds between.

        Return True as soon as the condition holds, False if it never did.
        """
        for _ in range(tries):
            if condition():
                return True
            sleep(delay)
        return False

    def setUp(self):
        """Setup test tables or empty them if they already exist."""
        db = opendb()
        try:
            for t in ('_test1', '_test2'):
                try:
                    db.query("CREATE SCHEMA " + t)
                except Error:
                    pass
                try:
                    db.query("CREATE TABLE %s._test_schema "
                        "(%s int PRIMARY KEY)" % (t, t))
                except Error:
                    db.query("DELETE FROM %s._test_schema" % t)
            try:
                db.query("CREATE TABLE _test_schema "
                    "(_test int PRIMARY KEY, _i interval, dvar int DEFAULT 999)")
            except Error:
                db.query("DELETE FROM _test_schema")
            try:
                db.query("CREATE VIEW _test_vschema AS "
                    "SELECT _test, 'abc'::text AS _test2 FROM _test_schema")
            except Error:
                pass
        finally:
            # This connection is only needed to prepare the tables.
            db.close()

    def test_invalidname(self):
        """Make sure that invalid table names are caught"""
        db = self._opendb()
        self.assertRaises(ProgrammingError, db.get_attnames, 'x.y.z')

    def test_schema(self):
        """Does it differentiate the same table name in different schemas"""
        db = self._opendb()
        # see if they differentiate the table names properly
        self.assertEqual(
            db.get_attnames('_test_schema'),
            {'_test': 'int', '_i': 'date', 'dvar': 'int'}
        )
        self.assertEqual(
            db.get_attnames('public._test_schema'),
            {'_test': 'int', '_i': 'date', 'dvar': 'int'}
        )
        self.assertEqual(
            db.get_attnames('_test1._test_schema'),
            {'_test1': 'int'}
        )
        self.assertEqual(
            db.get_attnames('_test2._test_schema'),
            {'_test2': 'int'}
        )

    def test_pkey(self):
        """Primary keys are found per schema; the view has none."""
        db = self._opendb()
        self.assertEqual(db.pkey('_test_schema'), '_test')
        self.assertEqual(db.pkey('public._test_schema'), '_test')
        self.assertEqual(db.pkey('_test1._test_schema'), '_test1')
        self.assertEqual(db.pkey('_test2._test_schema'), '_test2')
        self.assertRaises(KeyError, db.pkey, '_test_vschema')

    def test_get(self):
        """Rows can be fetched by primary key or by an explicit key column."""
        db = self._opendb()
        db.query("INSERT INTO _test_schema VALUES (1234)")
        r = db.get('_test_schema', 1234)
        self.assertEqual(r['_test'], 1234)
        r = db.get('_test_schema', 1234, keyname='_test')
        self.assertEqual(r['_test'], 1234)
        # The view has no primary key (see test_pkey), so get() needs an
        # explicit keyname for it.
        self.assertRaises(ProgrammingError, db.get, '_test_vschema', 1234)
        r = db.get('_test_vschema', 1234, keyname='_test')
        self.assertEqual(r['_test'], 1234)
        self.assertEqual(r['_test2'], 'abc')

    def test_params(self):
        """Positional query parameters are passed to the server."""
        db = self._opendb()
        db.query("INSERT INTO _test_schema VALUES ($1, $2, $3)", 12, None, 34)
        d = db.get('_test_schema', 12)
        self.assertEqual(d['dvar'], 34)

    def test_insert(self):
        """insert() reloads the row, picking up column defaults."""
        db = self._opendb()
        d = dict(_test=1234)
        db.insert('_test_schema', d)
        self.assertEqual(d['dvar'], 999)
        # With keyword arguments, the inserted row is only available
        # through the return value.
        d = db.insert('_test_schema', _test=1235)
        self.assertEqual(d['_test'], 1235)
        self.assertEqual(d['dvar'], 999)

    def test_context_manager(self):
        """A failing "with db" block rolls back; successful ones commit."""
        db = self._opendb()
        t = '_test_schema'
        d = dict(_test=1235)
        with db:
            db.insert(t, d)
            d['_test'] += 1
            db.insert(t, d)
        # The duplicate insert must raise, which rolls back the whole block
        # (so 1237 must not be in the table afterwards).
        with self.assertRaises(ProgrammingError):
            with db:
                d['_test'] += 1
                db.insert(t, d)
                db.insert(t, d)
        with db:
            d['_test'] += 1
            db.insert(t, d)
            d['_test'] += 1
            db.insert(t, d)
        self.assertTrue(db.get(t, 1235))
        self.assertTrue(db.get(t, 1236))
        self.assertRaises(DatabaseError, db.get, t, 1237)
        self.assertTrue(db.get(t, 1238))
        self.assertTrue(db.get(t, 1239))

    def test_sqlstate(self):
        """A unique violation is reported with SQLSTATE 23505."""
        db = self._opendb()
        db.query("INSERT INTO _test_schema VALUES (1234)")
        try:
            db.query("INSERT INTO _test_schema VALUES (1234)")
        except DatabaseError as error:
            # currently PyGreSQL does not support IntegrityError
            self.assertTrue(isinstance(error, ProgrammingError))
            # the SQLSTATE error code for unique violation is 23505
            self.assertEqual(error.sqlstate, '23505')
        else:
            self.fail("Duplicate primary key was not rejected")

    def test_mixed_case(self):
        """Inserting into a table with a quoted mixed-case column works."""
        db = self._opendb()
        try:
            db.query('CREATE TABLE _test_mc ("_Test" int PRIMARY KEY)')
        except Error:
            db.query("DELETE FROM _test_mc")
        d = dict(_Test=1234)
        db.insert('_test_mc', d)

    def test_update(self):
        """update() accepts a row dict, keyword arguments, or both."""
        db = self._opendb()
        db.query("INSERT INTO _test_schema VALUES (1234)")

        r = db.get('_test_schema', 1234)
        r['dvar'] = 123
        db.update('_test_schema', r)
        r = db.get('_test_schema', 1234)
        self.assertEqual(r['dvar'], 123)

        r = db.get('_test_schema', 1234)
        self.assertIn('dvar', r)
        db.update('_test_schema', _test=1234, dvar=456)
        r = db.get('_test_schema', 1234)
        self.assertEqual(r['dvar'], 456)

        r = db.get('_test_schema', 1234)
        db.update('_test_schema', r, dvar=456)
        r = db.get('_test_schema', 1234)
        self.assertEqual(r['dvar'], 456)

    def notify_callback(self, arg_dict):
        """Record a received notification, or a timeout if arg_dict is empty.

        A truthy arg_dict is marked as called. Otherwise (e.g. None),
        self.notify_timeout is set.
        """
        if arg_dict:
            arg_dict['called'] = True
        else:
            self.notify_timeout = True

    def test_notify(self, options=None):
        """Check that a NotificationHandler receives notifications.

        options may set the flags run_as_method (get the handler via
        db.notification_handler), call_notify (send with target.notify()
        instead of NOTIFY queries) and two_payloads (first send an extra
        notification in the same transaction).
        """
        if not options:
            options = {}
        run_as_method = options.get('run_as_method')
        call_notify = options.get('call_notify')
        two_payloads = options.get('two_payloads')
        db = opendb()
        # Get function under test, can be standalone or DB method.
        fut = db.notification_handler if run_as_method else partial(
            NotificationHandler, db)
        arg_dict = dict(event=None, called=False)
        self.notify_timeout = False
        # Listen for 'event_1' with a timeout of 5 seconds.
        target = fut('event_1', self.notify_callback, arg_dict, 5)
        thread = Thread(None, target)
        thread.start()
        db2 = None
        try:
            # Wait until the thread has started.
            self._wait_until(lambda: target.listening)
            self.assertTrue(target.listening)
            self.assertTrue(thread.is_alive())
            # Open another connection for sending notifications.
            db2 = opendb()
            # Generate notification from the other connection.
            if two_payloads:
                db2.begin()
            if call_notify:
                if two_payloads:
                    target.notify(db2, payload='payload 0')
                target.notify(db2, payload='payload 1')
            else:
                if two_payloads:
                    db2.query("notify event_1, 'payload 0'")
                db2.query("notify event_1, 'payload 1'")
            if two_payloads:
                db2.commit()
            # Wait until the notification has been caught.
            self._wait_until(
                lambda: arg_dict['called'] or self.notify_timeout)
            # Check that callback has been invoked.
            self.assertTrue(arg_dict['called'])
            self.assertEqual(arg_dict['event'], 'event_1')
            self.assertEqual(arg_dict['extra'], 'payload 1')
            self.assertTrue(isinstance(arg_dict['pid'], int))
            self.assertFalse(self.notify_timeout)
            arg_dict['called'] = False
            self.assertTrue(thread.is_alive())
            # Generate stop notification.
            if call_notify:
                target.notify(db2, stop=True, payload='payload 2')
            else:
                db2.query("notify stop_event_1, 'payload 2'")
            db2.close()
            db2 = None  # already closed, skip it in the cleanup below
            # Wait until the notification has been caught.
            self._wait_until(
                lambda: arg_dict['called'] or self.notify_timeout)
            # Check that callback has been invoked.
            self.assertTrue(arg_dict['called'])
            self.assertEqual(arg_dict['event'], 'stop_event_1')
            self.assertEqual(arg_dict['extra'], 'payload 2')
            self.assertTrue(isinstance(arg_dict['pid'], int))
            self.assertFalse(self.notify_timeout)
            thread.join(5)
            self.assertFalse(thread.is_alive())
            self.assertFalse(target.listening)
        finally:
            # Always shut down the listener and release the connections.
            # Using finally instead of "except Exception" lets assertion
            # failures propagate, so the test can actually fail.
            target.close()
            if thread.is_alive():
                thread.join(5)
            if db2 is not None:
                db2.close()

    def test_notify_other_options(self):
        """Run test_notify with every other combination of its options."""
        for run_as_method in False, True:
            for call_notify in False, True:
                for two_payloads in False, True:
                    options = dict(
                        run_as_method=run_as_method,
                        call_notify=call_notify,
                        two_payloads=two_payloads)
                    # the all-False combination is covered by test_notify
                    if any(options.values()):
                        self.test_notify(options)

    def test_notify_timeout(self):
        """The handler times out when no notification arrives."""
        for run_as_method in False, True:
            db = opendb()
            # Get function under test, can be standalone or DB method.
            fut = db.notification_handler if run_as_method else partial(
                NotificationHandler, db)
            arg_dict = dict(event=None, called=False)
            self.notify_timeout = False
            # Listen for 'event_1' with timeout of 10ms.
            target = fut('event_1', self.notify_callback, arg_dict, 0.01)
            thread = Thread(None, target)
            thread.start()
            try:
                # Nothing is sent, so the handler is expected to end via its
                # 10ms timeout.  Joining with a generous bound avoids the race
                # of a fixed short sleep on a slow or loaded machine.
                thread.join(1)
                # Verify that we've indeed timed out.
                self.assertFalse(arg_dict.get('called'))
                self.assertTrue(self.notify_timeout)
                self.assertFalse(thread.is_alive())
                self.assertFalse(target.listening)
            finally:
                target.close()
                if thread.is_alive():
                    thread.join(5)
320
321
if __name__ == '__main__':
    # Command line: [-l] | [-v] [-f] [test_name ...]
    args = sys.argv[1:]

    # A lone "-l" just lists the available test names.
    if args == ['-l']:
        print('\n'.join(unittest.getTestCaseNames(UtilityTest, 'test_')))
        sys.exit(0)

    test_list = [name for name in args if not name.startswith('-')]
    if not test_list:
        test_list = unittest.getTestCaseNames(UtilityTest, 'test_')

    suite = unittest.TestSuite()
    for test_name in test_list:
        try:
            suite.addTest(UtilityTest(test_name))
        except Exception as error:
            # Use the bound exception: sys.exc_value does not exist in
            # Python 3 and is deprecated in Python 2.
            print("\n ERROR: %s.\n" % error)
            sys.exit(1)

    verbosity = 2 if '-v' in args else 1
    # "-f" enables failfast; "-l" combined with other arguments did so
    # before and is still accepted for compatibility.
    failfast = '-f' in args or '-l' in args
    runner = unittest.TextTestRunner(verbosity=verbosity, failfast=failfast)
    rc = runner.run(suite)
    sys.exit(1 if rc.errors or rc.failures else 0)
Note: See TracBrowser for help on using the repository browser.