Changeset 748 for trunk/tests


Ignore:
Timestamp:
Jan 15, 2016, 9:25:31 AM (4 years ago)
Author:
cito
Message:

Add method truncate() to DB wrapper class

This methods can be used to quickly truncate tables.

Since this is pretty useful and will not break anything, I have
also back ported this addition to the 4.x branch.

Everything is well documented and tested, of course.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/tests/test_classic_dbwrapper.py

    r747 r748  
    108108            'set_notice_receiver', 'set_parameter',
    109109            'source', 'start', 'status',
    110             'transaction',
     110            'transaction', 'truncate',
    111111            'unescape_bytea', 'update', 'upsert',
    112112            'use_regtypes', 'user',
     
    15151515        self.assertEqual(r[0][0], 0)
    15161516
     1517    def testTruncate(self):
     1518        truncate = self.db.truncate
     1519        self.assertRaises(TypeError, truncate, None)
     1520        self.assertRaises(TypeError, truncate, 42)
     1521        self.assertRaises(TypeError, truncate, dict(test_table=None))
     1522        query = self.db.query
     1523        query("drop table if exists test_table")
     1524        self.addCleanup(query, "drop table test_table")
     1525        query("create table test_table (n smallint)")
     1526        for i in range(3):
     1527            query("insert into test_table values (1)")
     1528        q = "select count(*) from test_table"
     1529        r = query(q).getresult()[0][0]
     1530        self.assertEqual(r, 3)
     1531        truncate('test_table')
     1532        r = query(q).getresult()[0][0]
     1533        self.assertEqual(r, 0)
     1534        for i in range(3):
     1535            query("insert into test_table values (1)")
     1536        r = query(q).getresult()[0][0]
     1537        self.assertEqual(r, 3)
     1538        truncate('public.test_table')
     1539        r = query(q).getresult()[0][0]
     1540        self.assertEqual(r, 0)
     1541        query("drop table if exists test_table_2")
     1542        self.addCleanup(query, "drop table test_table_2")
     1543        query('create table test_table_2 (n smallint)')
     1544        for t in (list, tuple, set):
     1545            for i in range(3):
     1546                query("insert into test_table values (1)")
     1547                query("insert into test_table_2 values (2)")
     1548            q = ("select (select count(*) from test_table),"
     1549                " (select count(*) from test_table_2)")
     1550            r = query(q).getresult()[0]
     1551            self.assertEqual(r, (3, 3))
     1552            truncate(t(['test_table', 'test_table_2']))
     1553            r = query(q).getresult()[0]
     1554            self.assertEqual(r, (0, 0))
     1555
     1556    def testTruncateRestart(self):
     1557        truncate = self.db.truncate
     1558        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
     1559        query = self.db.query
     1560        query("drop table if exists test_table")
     1561        self.addCleanup(query, "drop table test_table")
     1562        query("create table test_table (n serial, t text)")
     1563        for n in range(3):
     1564            query("insert into test_table (t) values ('test')")
     1565        q = "select count(n), min(n), max(n) from test_table"
     1566        r = query(q).getresult()[0]
     1567        self.assertEqual(r, (3, 1, 3))
     1568        truncate('test_table')
     1569        r = query(q).getresult()[0]
     1570        self.assertEqual(r, (0, None, None))
     1571        for n in range(3):
     1572            query("insert into test_table (t) values ('test')")
     1573        r = query(q).getresult()[0]
     1574        self.assertEqual(r, (3, 4, 6))
     1575        truncate('test_table', restart=True)
     1576        r = query(q).getresult()[0]
     1577        self.assertEqual(r, (0, None, None))
     1578        for n in range(3):
     1579            query("insert into test_table (t) values ('test')")
     1580        r = query(q).getresult()[0]
     1581        self.assertEqual(r, (3, 1, 3))
     1582
     1583    def testTruncateCascade(self):
     1584        truncate = self.db.truncate
     1585        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
     1586        query = self.db.query
     1587        query("drop table if exists test_child")
     1588        query("drop table if exists test_parent")
     1589        self.addCleanup(query, "drop table test_parent")
     1590        query("create table test_parent (n smallint primary key)")
     1591        self.addCleanup(query, "drop table test_child")
     1592        query("create table test_child ("
     1593            " n smallint primary key references test_parent (n))")
     1594        for n in range(3):
     1595            query("insert into test_parent (n) values (%d)" % n)
     1596            query("insert into test_child (n) values (%d)" % n)
     1597        q = ("select (select count(*) from test_parent),"
     1598            " (select count(*) from test_child)")
     1599        r = query(q).getresult()[0]
     1600        self.assertEqual(r, (3, 3))
     1601        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
     1602        truncate(['test_parent', 'test_child'])
     1603        r = query(q).getresult()[0]
     1604        self.assertEqual(r, (0, 0))
     1605        for n in range(3):
     1606            query("insert into test_parent (n) values (%d)" % n)
     1607            query("insert into test_child (n) values (%d)" % n)
     1608        r = query(q).getresult()[0]
     1609        self.assertEqual(r, (3, 3))
     1610        truncate('test_parent', cascade=True)
     1611        r = query(q).getresult()[0]
     1612        self.assertEqual(r, (0, 0))
     1613        for n in range(3):
     1614            query("insert into test_parent (n) values (%d)" % n)
     1615            query("insert into test_child (n) values (%d)" % n)
     1616        r = query(q).getresult()[0]
     1617        self.assertEqual(r, (3, 3))
     1618        truncate('test_child')
     1619        r = query(q).getresult()[0]
     1620        self.assertEqual(r, (3, 0))
     1621        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
     1622        truncate('test_parent', cascade=True)
     1623        r = query(q).getresult()[0]
     1624        self.assertEqual(r, (0, 0))
     1625
     1626    def testTruncateOnly(self):
     1627        truncate = self.db.truncate
     1628        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
     1629        query = self.db.query
     1630        query("drop table if exists test_child")
     1631        query("drop table if exists test_parent")
     1632        self.addCleanup(query, "drop table test_parent")
     1633        query("create table test_parent (n smallint)")
     1634        self.addCleanup(query, "drop table test_child")
     1635        query("create table test_child ("
     1636            " m smallint) inherits (test_parent)")
     1637        for n in range(3):
     1638            query("insert into test_parent (n) values (1)")
     1639            query("insert into test_child (n, m) values (2, 3)")
     1640        q = ("select (select count(*) from test_parent),"
     1641            " (select count(*) from test_child)")
     1642        r = query(q).getresult()[0]
     1643        self.assertEqual(r, (6, 3))
     1644        truncate('test_parent')
     1645        r = query(q).getresult()[0]
     1646        self.assertEqual(r, (0, 0))
     1647        for n in range(3):
     1648            query("insert into test_parent (n) values (1)")
     1649            query("insert into test_child (n, m) values (2, 3)")
     1650        r = query(q).getresult()[0]
     1651        self.assertEqual(r, (6, 3))
     1652        truncate('test_parent*')
     1653        r = query(q).getresult()[0]
     1654        self.assertEqual(r, (0, 0))
     1655        for n in range(3):
     1656            query("insert into test_parent (n) values (1)")
     1657            query("insert into test_child (n, m) values (2, 3)")
     1658        r = query(q).getresult()[0]
     1659        self.assertEqual(r, (6, 3))
     1660        truncate('test_parent', only=True)
     1661        r = query(q).getresult()[0]
     1662        self.assertEqual(r, (3, 3))
     1663        truncate('test_parent', only=False)
     1664        r = query(q).getresult()[0]
     1665        self.assertEqual(r, (0, 0))
     1666        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
     1667        truncate('test_parent*', only=False)
     1668        query("drop table if exists test_parent_2")
     1669        self.addCleanup(query, "drop table test_parent_2")
     1670        query("create table test_parent_2 (n smallint)")
     1671        query("drop table if exists test_child_2")
     1672        self.addCleanup(query, "drop table test_child_2")
     1673        query("create table test_child_2 ("
     1674            " m smallint) inherits (test_parent_2)")
     1675        for n in range(3):
     1676            query("insert into test_parent (n) values (1)")
     1677            query("insert into test_child (n, m) values (2, 3)")
     1678            query("insert into test_parent_2 (n) values (1)")
     1679            query("insert into test_child_2 (n, m) values (2, 3)")
     1680        q = ("select (select count(*) from test_parent),"
     1681            " (select count(*) from test_child),"
     1682            " (select count(*) from test_parent_2),"
     1683            " (select count(*) from test_child_2)")
     1684        r = query(q).getresult()[0]
     1685        self.assertEqual(r, (6, 3, 6, 3))
     1686        truncate(['test_parent', 'test_parent_2'], only=[False, True])
     1687        r = query(q).getresult()[0]
     1688        self.assertEqual(r, (0, 0, 3, 3))
     1689        truncate(['test_parent', 'test_parent_2'], only=False)
     1690        r = query(q).getresult()[0]
     1691        self.assertEqual(r, (0, 0, 0, 0))
     1692        self.assertRaises(ValueError, truncate,
     1693            ['test_parent*', 'test_child'], only=[True, False])
     1694        truncate(['test_parent*', 'test_child'], only=[False, True])
     1695
     1696    def testTruncateQuoted(self):
     1697        truncate = self.db.truncate
     1698        query = self.db.query
     1699        table = "test table for truncate()"
     1700        query('drop table if exists "%s"' % table)
     1701        self.addCleanup(query, 'drop table "%s"' % table)
     1702        query('create table "%s" (n smallint)' % table)
     1703        for i in range(3):
     1704            query('insert into "%s" values (1)' % table)
     1705        q = 'select count(*) from "%s"' % table
     1706        r = query(q).getresult()[0][0]
     1707        self.assertEqual(r, 3)
     1708        truncate(table)
     1709        r = query(q).getresult()[0][0]
     1710        self.assertEqual(r, 0)
     1711        for i in range(3):
     1712            query('insert into "%s" values (1)' % table)
     1713        r = query(q).getresult()[0][0]
     1714        self.assertEqual(r, 3)
     1715        truncate('public."%s"' % table)
     1716        r = query(q).getresult()[0][0]
     1717        self.assertEqual(r, 0)
     1718
    15171719    def testTransaction(self):
    15181720        query = self.db.query
Note: See TracChangeset for help on using the changeset viewer.