source: branches/4.x/module/tests/test_classic.py @ 647

Last change on this file since 647 was 647, checked in by cito, 4 years ago

Make tests compatible with Python 2.5

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 13.1 KB
Line 
1#! /usr/bin/python
2
3from __future__ import with_statement
4
5import sys
6from functools import partial
7from time import sleep
8from threading import Thread
9import unittest
10from pg import *
11
12# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
13# get our information from that.  Otherwise we use the defaults.
14dbname = 'unittest'
15dbhost = None
16dbport = 5432
17
18try:
19    from LOCAL_PyGreSQL import *
20except ImportError:
21    pass
22
23def opendb():
24    db = DB(dbname, dbhost, dbport)
25    db.query("SET DATESTYLE TO 'ISO'")
26    db.query("SET TIME ZONE 'EST5EDT'")
27    db.query("SET DEFAULT_WITH_OIDS=FALSE")
28    db.query("SET STANDARD_CONFORMING_STRINGS=FALSE")
29    return db
30
31
32class UtilityTest(unittest.TestCase):
33
34    def setUp(self):
35        """Setup test tables or empty them if they already exist."""
36        db = opendb()
37
38        for t in ('_test1', '_test2'):
39            try:
40                db.query("CREATE SCHEMA " + t)
41            except Error:
42                pass
43            try:
44                db.query("CREATE TABLE %s._test_schema "
45                    "(%s int PRIMARY KEY)" % (t, t))
46            except Error:
47                db.query("DELETE FROM %s._test_schema" % t)
48        try:
49            db.query("CREATE TABLE _test_schema "
50                "(_test int PRIMARY KEY, _i interval, dvar int DEFAULT 999)")
51        except Error:
52            db.query("DELETE FROM _test_schema")
53        try:
54            db.query("CREATE VIEW _test_vschema AS "
55                "SELECT _test, 'abc'::text AS _test2  FROM _test_schema")
56        except Error:
57            pass
58
59    def test_invalidname(self):
60        """Make sure that invalid table names are caught"""
61        db = opendb()
62        self.failUnlessRaises(ProgrammingError, db.get_attnames, 'x.y.z')
63
64    def test_schema(self):
65        """Does it differentiate the same table name in different schemas"""
66        db = opendb()
67        # see if they differentiate the table names properly
68        self.assertEqual(
69            db.get_attnames('_test_schema'),
70            {'_test': 'int', '_i': 'date', 'dvar': 'int'}
71        )
72        self.assertEqual(
73            db.get_attnames('public._test_schema'),
74            {'_test': 'int', '_i': 'date', 'dvar': 'int'}
75        )
76        self.assertEqual(
77            db.get_attnames('_test1._test_schema'),
78            {'_test1': 'int'}
79        )
80        self.assertEqual(
81            db.get_attnames('_test2._test_schema'),
82            {'_test2': 'int'}
83        )
84
85    def test_pkey(self):
86        db = opendb()
87        self.assertEqual(db.pkey('_test_schema'), '_test')
88        self.assertEqual(db.pkey('public._test_schema'), '_test')
89        self.assertEqual(db.pkey('_test1._test_schema'), '_test1')
90
91        self.assertEqual(db.pkey('_test_schema',
92                {'test1': 'a', 'test2.test3': 'b'}),
93                {'public.test1': 'a', 'test2.test3': 'b'})
94        self.assertEqual(db.pkey('test1'), 'a')
95        self.assertEqual(db.pkey('public.test1'), 'a')
96
97    def test_get(self):
98        db = opendb()
99        db.query("INSERT INTO _test_schema VALUES (1234)")
100        db.get('_test_schema', 1234)
101        db.get('_test_schema', 1234, keyname='_test')
102        self.failUnlessRaises(ProgrammingError, db.get, '_test_vschema', 1234)
103        db.get('_test_vschema', 1234, keyname='_test')
104
105    def test_params(self):
106        db = opendb()
107        db.query("INSERT INTO _test_schema VALUES ($1, $2, $3)", 12, None, 34)
108        d = db.get('_test_schema', 12)
109        self.assertEqual(d['dvar'], 34)
110
111    def test_insert(self):
112        db = opendb()
113        d = dict(_test=1234)
114        db.insert('_test_schema', d)
115        self.assertEqual(d['dvar'], 999)
116        db.insert('_test_schema', _test=1235)
117        self.assertEqual(d['dvar'], 999)
118
119    def test_context_manager(self):
120        db = opendb()
121        t = '_test_schema'
122        d = dict(_test=1235)
123        with db:
124            db.insert(t, d)
125            d['_test'] += 1
126            db.insert(t, d)
127        try:
128            with db:
129                d['_test'] += 1
130                db.insert(t, d)
131                db.insert(t, d)
132        except ProgrammingError:
133            pass
134        with db:
135            d['_test'] += 1
136            db.insert(t, d)
137            d['_test'] += 1
138            db.insert(t, d)
139        self.assertTrue(db.get(t, 1235))
140        self.assertTrue(db.get(t, 1236))
141        self.assertRaises(DatabaseError, db.get, t, 1237)
142        self.assertTrue(db.get(t, 1238))
143        self.assertTrue(db.get(t, 1239))
144
145    def test_sqlstate(self):
146        db = opendb()
147        db.query("INSERT INTO _test_schema VALUES (1234)")
148        try:
149            db.query("INSERT INTO _test_schema VALUES (1234)")
150        except DatabaseError, error:
151            # currently PyGreSQL does not support IntegrityError
152            self.assert_(isinstance(error, ProgrammingError))
153            # the SQLSTATE error code for unique violation is 23505
154            self.assertEqual(error.sqlstate, '23505')
155
156    def test_mixed_case(self):
157        db = opendb()
158        try:
159            db.query('CREATE TABLE _test_mc ("_Test" int PRIMARY KEY)')
160        except Error:
161            db.query("DELETE FROM _test_mc")
162        d = dict(_Test=1234)
163        db.insert('_test_mc', d)
164
165    def test_update(self):
166        db = opendb()
167        db.query("INSERT INTO _test_schema VALUES (1234)")
168
169        r = db.get('_test_schema', 1234)
170        r['dvar'] = 123
171        db.update('_test_schema', r)
172        r = db.get('_test_schema', 1234)
173        self.assertEqual(r['dvar'], 123)
174
175        r = db.get('_test_schema', 1234)
176        db.update('_test_schema', _test=1234, dvar=456)
177        r = db.get('_test_schema', 1234)
178        self.assertEqual(r['dvar'], 456)
179
180        r = db.get('_test_schema', 1234)
181        db.update('_test_schema', r, dvar=456)
182        r = db.get('_test_schema', 1234)
183        self.assertEqual(r['dvar'], 456)
184
185    def test_quote(self):
186        db = opendb()
187        q = db._quote
188        self.assertEqual(q(0, 'int'), "0")
189        self.assertEqual(q(0, 'num'), "0")
190        self.assertEqual(q('0', 'int'), "0")
191        self.assertEqual(q('0', 'num'), "0")
192        self.assertEqual(q(1, 'int'), "1")
193        self.assertEqual(q(1, 'text'), "'1'")
194        self.assertEqual(q(1, 'num'), "1")
195        self.assertEqual(q('1', 'int'), "1")
196        self.assertEqual(q('1', 'text'), "'1'")
197        self.assertEqual(q('1', 'num'), "1")
198        self.assertEqual(q(None, 'int'), "NULL")
199        self.assertEqual(q(1, 'money'), "1")
200        self.assertEqual(q('1', 'money'), "1")
201        self.assertEqual(q(1.234, 'money'), "1.234")
202        self.assertEqual(q('1.234', 'money'), "1.234")
203        self.assertEqual(q(0, 'money'), "0")
204        self.assertEqual(q(0.00, 'money'), "0.0")
205        self.assertEqual(q(Decimal('0.00'), 'money'), "0.00")
206        self.assertEqual(q(None, 'money'), "NULL")
207        self.assertEqual(q('', 'money'), "NULL")
208        self.assertEqual(q(0, 'bool'), "'f'")
209        self.assertEqual(q('', 'bool'), "NULL")
210        self.assertEqual(q('f', 'bool'), "'f'")
211        self.assertEqual(q('off', 'bool'), "'f'")
212        self.assertEqual(q('no', 'bool'), "'f'")
213        self.assertEqual(q(1, 'bool'), "'t'")
214        self.assertEqual(q(9999, 'bool'), "'t'")
215        self.assertEqual(q(-9999, 'bool'), "'t'")
216        self.assertEqual(q('1', 'bool'), "'t'")
217        self.assertEqual(q('t', 'bool'), "'t'")
218        self.assertEqual(q('on', 'bool'), "'t'")
219        self.assertEqual(q('yes', 'bool'), "'t'")
220        self.assertEqual(q('true', 'bool'), "'t'")
221        self.assertEqual(q('y', 'bool'), "'t'")
222        self.assertEqual(q('', 'date'), "NULL")
223        self.assertEqual(q(False, 'date'), "NULL")
224        self.assertEqual(q(0, 'date'), "NULL")
225        self.assertEqual(q('some_date', 'date'), "'some_date'")
226        self.assertEqual(q('current_timestamp', 'date'), "current_timestamp")
227        self.assertEqual(q('', 'text'), "''")
228        self.assertEqual(q("'", 'text'), "''''")
229        self.assertEqual(q("\\", 'text'), "'\\\\'")
230
231    def notify_callback(self, arg_dict):
232        if arg_dict:
233            arg_dict['called'] = True
234        else:
235            self.notify_timeout = True
236
237    def test_notify(self, options=None):
238        if not options:
239            options = {}
240        run_as_method = options.get('run_as_method')
241        call_notify = options.get('call_notify')
242        two_payloads = options.get('two_payloads')
243        db = opendb()
244        # Get function under test, can be standalone or DB method.
245        fut = db.notification_handler if run_as_method else partial(
246            NotificationHandler, db)
247        arg_dict = dict(event=None, called=False)
248        self.notify_timeout = False
249        # Listen for 'event_1'.
250        target = fut('event_1', self.notify_callback, arg_dict, 5)
251        thread = Thread(None, target)
252        thread.start()
253        try:
254            # Wait until the thread has started.
255            for n in range(500):
256                if target.listening:
257                    break
258                sleep(0.01)
259            self.assertTrue(target.listening)
260            self.assertTrue(thread.isAlive())
261            # Open another connection for sending notifications.
262            db2 = opendb()
263            # Generate notification from the other connection.
264            if two_payloads:
265                db2.begin()
266            if call_notify:
267                if two_payloads:
268                    target.notify(db2, payload='payload 0')
269                target.notify(db2, payload='payload 1')
270            else:
271                if two_payloads:
272                    db2.query("notify event_1, 'payload 0'")
273                db2.query("notify event_1, 'payload 1'")
274            if two_payloads:
275                db2.commit()
276            # Wait until the notification has been caught.
277            for n in range(500):
278                if arg_dict['called'] or self.notify_timeout:
279                    break
280                sleep(0.01)
281            # Check that callback has been invoked.
282            self.assertTrue(arg_dict['called'])
283            self.assertEqual(arg_dict['event'], 'event_1')
284            self.assertEqual(arg_dict['extra'], 'payload 1')
285            self.assertTrue(isinstance(arg_dict['pid'], int))
286            self.assertFalse(self.notify_timeout)
287            arg_dict['called'] = False
288            self.assertTrue(thread.isAlive())
289            # Generate stop notification.
290            if call_notify:
291                target.notify(db2, stop=True, payload='payload 2')
292            else:
293                db2.query("notify stop_event_1, 'payload 2'")
294            db2.close()
295            # Wait until the notification has been caught.
296            for n in range(500):
297                if arg_dict['called'] or self.notify_timeout:
298                    break
299                sleep(0.01)
300            # Check that callback has been invoked.
301            self.assertTrue(arg_dict['called'])
302            self.assertEqual(arg_dict['event'], 'stop_event_1')
303            self.assertEqual(arg_dict['extra'], 'payload 2')
304            self.assertTrue(isinstance(arg_dict['pid'], int))
305            self.assertFalse(self.notify_timeout)
306            thread.join(5)
307            self.assertFalse(thread.isAlive())
308            self.assertFalse(target.listening)
309        finally:
310            target.close()
311            if thread.isAlive():
312                thread.join(5)
313
314    def test_notify_other_options(self):
315        for run_as_method in False, True:
316            for call_notify in False, True:
317                for two_payloads in False, True:
318                    options = dict(
319                        run_as_method=run_as_method,
320                        call_notify=call_notify,
321                        two_payloads=two_payloads)
322                    if True in options.values():
323                        self.test_notify(options)
324
325    def test_notify_timeout(self):
326        for run_as_method in False, True:
327            db = opendb()
328            # Get function under test, can be standalone or DB method.
329            fut = db.notification_handler if run_as_method else partial(
330                NotificationHandler, db)
331            arg_dict = dict(event=None, called=False)
332            self.notify_timeout = False
333            # Listen for 'event_1' with timeout of 10ms.
334            target = fut('event_1', self.notify_callback, arg_dict, 0.01)
335            thread = Thread(None, target)
336            thread.start()
337            # Sleep 20ms, long enough to time out.
338            sleep(0.02)
339            # Verify that we've indeed timed out.
340            self.assertFalse(arg_dict.get('called'))
341            self.assertTrue(self.notify_timeout)
342            self.assertFalse(thread.isAlive())
343            self.assertFalse(target.listening)
344            target.close()
345
346
347if __name__ == '__main__':
348    suite = unittest.TestSuite()
349
350    if len(sys.argv) > 1: test_list = sys.argv[1:]
351    else: test_list = unittest.getTestCaseNames(UtilityTest, 'test_')
352
353    if len(sys.argv) == 2 and sys.argv[1] == '-l':
354        print '\n'.join(unittest.getTestCaseNames(UtilityTest, 'test_'))
355        sys.exit(1)
356
357    for test_name in test_list:
358        try:
359            suite.addTest(UtilityTest(test_name))
360        except:
361            print "\n ERROR: %s.\n" % sys.exc_value
362            sys.exit(1)
363
364    rc = unittest.TextTestRunner(verbosity=1).run(suite)
365    sys.exit(len(rc.errors+rc.failures) != 0)
366
Note: See TracBrowser for help on using the repository browser.