Changeset 539


Ignore:
Timestamp:
Nov 18, 2015, 11:24:06 AM (4 years ago)
Author:
cito
Message:

Refactor the test_pg module into three smaller modules

This refactoring greatly simplifies the huge test_pg module
by splitting it into three individual modules and making the
code more modern, shorter and readable. Also, instead of
creating a test database, we require an existing database
that can be configured in an external module, just like the
other existing test modules. We now also use similar names.

When running the tests from the command line, you can pass
the -v flag for verbosity and the -f flag for fast failing.

Location:
branches/4.x/module
Files:
2 added
1 moved

Legend:

Unmodified
Added
Removed
  • branches/4.x/module/TEST_PyGreSQL_classic_dbwrapper.py

    r537 r539  
    11#! /usr/bin/python
    22# -*- coding: utf-8 -*-
    3 #
    4 # test_pg.py
    5 #
    6 # Written by Christoph Zwerschke
    7 #
    8 # $Id$
    9 #
    10 
    11 """Test the classic PyGreSQL interface in the pg module.
    12 
    13 The testing is done against a real local PostgreSQL database.
    14 
    15 There are a few drawbacks:
    16   * A local PostgreSQL database must be up and running, and
    17     the database user who is running the tests must be a trusted superuser.
    18   * The PostgreSQL database cluster must have been created with UTF-8 encoding.
    19   * The performance of the API is not tested.
    20   * Connecting to a remote host is not tested.
    21   * Passing user, password and options is not tested.
    22   * Table privilege problems (e.g. insert but no select) are not tested.
    23   * Status and error messages from the connection are not tested.
    24   * It would be more reasonable to create a test for the underlying
    25     shared library functions in the _pg module and assume they are ok.
    26     The pg and pgdb modules should be tested against _pg mock functions.
     3
     4"""Test the classic PyGreSQL interface.
     5
     6Sub-tests for the DB wrapper object.
     7
     8Contributed by Christoph Zwerschke.
     9
     10These tests need a database to test against.
    2711
    2812"""
     
    3014from __future__ import with_statement
    3115
    32 import pg
    33 
    34 import sys
    35 import unittest
    36 import locale
     16try:
     17    import unittest2 as unittest  # for Python < 2.6
     18except ImportError:
     19    import unittest
     20
     21import pg  # the module under test
    3722
    3823from decimal import Decimal
     24
     25# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
     26# get our information from that.  Otherwise we use the defaults.
     27# The current user must have create schema privilege on the database.
     28dbname = 'unittest'
     29dbhost = None
     30dbport = 5432
     31
     32debug = False  # let DB wrapper print debugging output
     33
    3934try:
    40     from collections import namedtuple
    41 except ImportError:  # Python < 2.6
    42     namedtuple = None
    43 
    44 debug = False
    45 
    46 locale.setlocale(locale.LC_ALL, '')
    47 encoding = locale.getlocale()[1] or 'utf-8'
    48 if encoding.lower() == 'utf-8':
    49     recode = lambda s: s
    50 else:
    51     recode = lambda s: s.decode('utf-8').encode(encoding)
    52 
    53 
    54 def smart_ddl(conn, cmd):
    55     """Execute DDL, but don't complain about minor things."""
    56     try:
    57         if cmd.startswith('create table '):
    58             i = cmd.find(' as select ')
    59             if i < 0:
    60                 i = len(cmd)
    61             conn.query(cmd[:i] + ' with oids' + cmd[i:])
    62         else:
    63             conn.query(cmd)
    64     except pg.ProgrammingError:
    65         if (cmd.startswith('drop table ')
    66             or cmd.startswith('set ')
    67             or cmd.startswith('alter database ')):
    68             pass
    69         elif cmd.startswith('create table '):
    70             conn.query(cmd)
    71         else:
    72             raise
    73 
    74 
    75 class TestAuxiliaryFunctions(unittest.TestCase):
    76     """Test the auxiliary functions external to the connection class."""
    77 
    78     def testIsQuoted(self):
    79         f = pg._is_quoted
    80         self.assertTrue(f('A'))
    81         self.assertTrue(f('0'))
    82         self.assertTrue(f('#'))
    83         self.assertTrue(f('*'))
    84         self.assertTrue(f('.'))
    85         self.assertTrue(f(' '))
    86         self.assertTrue(f('a b'))
    87         self.assertTrue(f('a+b'))
    88         self.assertTrue(f('a*b'))
    89         self.assertTrue(f('a.b'))
    90         self.assertTrue(f('0ab'))
    91         self.assertTrue(f('aBc'))
    92         self.assertTrue(f('ABC'))
    93         self.assertTrue(f('"a"'))
    94         self.assertTrue(not f('a'))
    95         self.assertTrue(not f('a0'))
    96         self.assertTrue(not f('_'))
    97         self.assertTrue(not f('_a'))
    98         self.assertTrue(not f('_0'))
    99         self.assertTrue(not f('_a_0_'))
    100         self.assertTrue(not f('ab'))
    101         self.assertTrue(not f('ab0'))
    102         self.assertTrue(not f('abc'))
    103         self.assertTrue(not f('abc'))
    104         if recode('À').isalpha():
    105             e = lambda s: f(recode(s))
    106             self.assertTrue(not e('À'))
    107             self.assertTrue(e('Ä'))
    108             self.assertTrue(not e('kÀse'))
    109             self.assertTrue(e('KÀse'))
    110             self.assertTrue(not e('emmentaler_kÀse'))
    111             self.assertTrue(e('emmentaler kÀse'))
    112             self.assertTrue(e('EmmentalerKÀse'))
    113             self.assertTrue(e('Emmentaler KÀse'))
    114 
    115     def testIsUnquoted(self):
    116         f = pg._is_unquoted
    117         self.assertTrue(f('A'))
    118         self.assertTrue(not f('0'))
    119         self.assertTrue(not f('#'))
    120         self.assertTrue(not f('*'))
    121         self.assertTrue(not f('.'))
    122         self.assertTrue(not f(' '))
    123         self.assertTrue(not f('a b'))
    124         self.assertTrue(not f('a+b'))
    125         self.assertTrue(not f('a*b'))
    126         self.assertTrue(not f('a.b'))
    127         self.assertTrue(not f('0ab'))
    128         self.assertTrue(f('aBc'))
    129         self.assertTrue(f('ABC'))
    130         self.assertTrue(not f('"a"'))
    131         self.assertTrue(f('a0'))
    132         self.assertTrue(f('_'))
    133         self.assertTrue(f('_a'))
    134         self.assertTrue(f('_0'))
    135         self.assertTrue(f('_a_0_'))
    136         self.assertTrue(f('ab'))
    137         self.assertTrue(f('ab0'))
    138         self.assertTrue(f('abc'))
    139         if recode('À').isalpha():
    140             e = lambda s: f(recode(s))
    141             self.assertTrue(e('À'))
    142             self.assertTrue(e('Ä'))
    143             self.assertTrue(e('kÀse'))
    144             self.assertTrue(e('KÀse'))
    145             self.assertTrue(e('emmentaler_kÀse'))
    146             self.assertTrue(not e('emmentaler kÀse'))
    147             self.assertTrue(e('EmmentalerKÀse'))
    148             self.assertTrue(not e('Emmentaler KÀse'))
    149 
    150     def testSplitFirstPart(self):
    151         f = pg._split_first_part
    152         self.assertEqual(f('a.b'), ['a', 'b'])
    153         self.assertEqual(f('a.b.c'), ['a', 'b.c'])
    154         self.assertEqual(f('"a.b".c'), ['a.b', 'c'])
    155         self.assertEqual(f('a."b.c"'), ['a', '"b.c"'])
    156         self.assertEqual(f('A.b.c'), ['a', 'b.c'])
    157         self.assertEqual(f('Ab.c'), ['ab', 'c'])
    158         self.assertEqual(f('aB.c'), ['ab', 'c'])
    159         self.assertEqual(f('AB.c'), ['ab', 'c'])
    160         self.assertEqual(f('A b.c'), ['A b', 'c'])
    161         self.assertEqual(f('a B.c'), ['a B', 'c'])
    162         self.assertEqual(f('"A".b.c'), ['A', 'b.c'])
    163         self.assertEqual(f('"A""B".c'), ['A"B', 'c'])
    164         self.assertEqual(f('a.b.c.d.e.f.g'), ['a', 'b.c.d.e.f.g'])
    165         self.assertEqual(f('"a.b.c.d.e.f".g'), ['a.b.c.d.e.f', 'g'])
    166         self.assertEqual(f('a.B.c.D.e.F.g'), ['a', 'B.c.D.e.F.g'])
    167         self.assertEqual(f('A.b.C.d.E.f.G'), ['a', 'b.C.d.E.f.G'])
    168 
    169     def testSplitParts(self):
    170         f = pg._split_parts
    171         self.assertEqual(f('a.b'), ['a', 'b'])
    172         self.assertEqual(f('a.b.c'), ['a', 'b', 'c'])
    173         self.assertEqual(f('"a.b".c'), ['a.b', 'c'])
    174         self.assertEqual(f('a."b.c"'), ['a', 'b.c'])
    175         self.assertEqual(f('A.b.c'), ['a', 'b', 'c'])
    176         self.assertEqual(f('Ab.c'), ['ab', 'c'])
    177         self.assertEqual(f('aB.c'), ['ab', 'c'])
    178         self.assertEqual(f('AB.c'), ['ab', 'c'])
    179         self.assertEqual(f('A b.c'), ['A b', 'c'])
    180         self.assertEqual(f('a B.c'), ['a B', 'c'])
    181         self.assertEqual(f('"A".b.c'), ['A', 'b', 'c'])
    182         self.assertEqual(f('"A""B".c'), ['A"B', 'c'])
    183         self.assertEqual(f('a.b.c.d.e.f.g'),
    184             ['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    185         self.assertEqual(f('"a.b.c.d.e.f".g'),
    186             ['a.b.c.d.e.f', 'g'])
    187         self.assertEqual(f('a.B.c.D.e.F.g'),
    188             ['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    189         self.assertEqual(f('A.b.C.d.E.f.G'),
    190             ['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    191 
    192     def testJoinParts(self):
    193         f = pg._join_parts
    194         self.assertEqual(f(('a',)), 'a')
    195         self.assertEqual(f(('a', 'b')), 'a.b')
    196         self.assertEqual(f(('a', 'b', 'c')), 'a.b.c')
    197         self.assertEqual(f(('a', 'b', 'c', 'd', 'e', 'f', 'g')),
    198             'a.b.c.d.e.f.g')
    199         self.assertEqual(f(('A', 'b')), '"A".b')
    200         self.assertEqual(f(('a', 'B')), 'a."B"')
    201         self.assertEqual(f(('a b', 'c')), '"a b".c')
    202         self.assertEqual(f(('a', 'b c')), 'a."b c"')
    203         self.assertEqual(f(('a_b', 'c')), 'a_b.c')
    204         self.assertEqual(f(('a', 'b_c')), 'a.b_c')
    205         self.assertEqual(f(('0', 'a')), '"0".a')
    206         self.assertEqual(f(('0_', 'a')), '"0_".a')
    207         self.assertEqual(f(('_0', 'a')), '_0.a')
    208         self.assertEqual(f(('_a', 'b')), '_a.b')
    209         self.assertEqual(f(('a', 'B', '0', 'c0', 'C0',
    210             'd e', 'f_g', 'h.i', 'jklm', 'nopq')),
    211             'a."B"."0".c0."C0"."d e".f_g."h.i".jklm.nopq')
    212 
    213     def testOidKey(self):
    214         f = pg._oid_key
    215         self.assertEqual(f('a'), 'oid(a)')
    216         self.assertEqual(f('a.b'), 'oid(a.b)')
    217 
    218 
    219 class TestHasConnect(unittest.TestCase):
    220     """Test existence of basic pg module functions."""
    221 
    222     def testhasPgError(self):
    223         self.assertTrue(issubclass(pg.Error, StandardError))
    224 
    225     def testhasPgWarning(self):
    226         self.assertTrue(issubclass(pg.Warning, StandardError))
    227 
    228     def testhasPgInterfaceError(self):
    229         self.assertTrue(issubclass(pg.InterfaceError, pg.Error))
    230 
    231     def testhasPgDatabaseError(self):
    232         self.assertTrue(issubclass(pg.DatabaseError, pg.Error))
    233 
    234     def testhasPgInternalError(self):
    235         self.assertTrue(issubclass(pg.InternalError, pg.DatabaseError))
    236 
    237     def testhasPgOperationalError(self):
    238         self.assertTrue(issubclass(pg.OperationalError, pg.DatabaseError))
    239 
    240     def testhasPgProgrammingError(self):
    241         self.assertTrue(issubclass(pg.ProgrammingError, pg.DatabaseError))
    242 
    243     def testhasPgIntegrityError(self):
    244         self.assertTrue(issubclass(pg.IntegrityError, pg.DatabaseError))
    245 
    246     def testhasPgDataError(self):
    247         self.assertTrue(issubclass(pg.DataError, pg.DatabaseError))
    248 
    249     def testhasPgNotSupportedError(self):
    250         self.assertTrue(issubclass(pg.NotSupportedError, pg.DatabaseError))
    251 
    252     def testhasConnect(self):
    253         self.assertTrue(callable(pg.connect))
    254 
    255     def testhasEscapeString(self):
    256         self.assertTrue(callable(pg.escape_string))
    257 
    258     def testhasEscapeBytea(self):
    259         self.assertTrue(callable(pg.escape_bytea))
    260 
    261     def testhasUnescapeBytea(self):
    262         self.assertTrue(callable(pg.unescape_bytea))
    263 
    264     def testDefHost(self):
    265         d0 = pg.get_defhost()
    266         d1 = 'pgtesthost'
    267         pg.set_defhost(d1)
    268         self.assertEqual(pg.get_defhost(), d1)
    269         pg.set_defhost(d0)
    270         self.assertEqual(pg.get_defhost(), d0)
    271 
    272     def testDefPort(self):
    273         d0 = pg.get_defport()
    274         d1 = 1234
    275         pg.set_defport(d1)
    276         self.assertEqual(pg.get_defport(), d1)
    277         if d0 is None:
    278             d0 = -1
    279         pg.set_defport(d0)
    280         if d0 == -1:
    281             d0 = None
    282         self.assertEqual(pg.get_defport(), d0)
    283 
    284     def testDefOpt(self):
    285         d0 = pg.get_defopt()
    286         d1 = '-h pgtesthost -p 1234'
    287         pg.set_defopt(d1)
    288         self.assertEqual(pg.get_defopt(), d1)
    289         pg.set_defopt(d0)
    290         self.assertEqual(pg.get_defopt(), d0)
    291 
    292     def testDefTty(self):
    293         d0 = pg.get_deftty()
    294         d1 = 'pgtesttty'
    295         pg.set_deftty(d1)
    296         self.assertEqual(pg.get_deftty(), d1)
    297         pg.set_deftty(d0)
    298         self.assertEqual(pg.get_deftty(), d0)
    299 
    300     def testDefBase(self):
    301         d0 = pg.get_defbase()
    302         d1 = 'pgtestdb'
    303         pg.set_defbase(d1)
    304         self.assertEqual(pg.get_defbase(), d1)
    305         pg.set_defbase(d0)
    306         self.assertEqual(pg.get_defbase(), d0)
    307 
    308 
    309 class TestEscapeFunctions(unittest.TestCase):
    310     """"Test pg escape and unescape functions."""
    311 
    312     def testEscapeString(self):
    313         self.assertEqual(pg.escape_string('plain'), 'plain')
    314         self.assertEqual(pg.escape_string(
    315             "that's k\xe4se"), "that''s k\xe4se")
    316         self.assertEqual(pg.escape_string(
    317             r"It's fine to have a \ inside."),
    318             r"It''s fine to have a \\ inside.")
    319 
    320     def testEscapeBytea(self):
    321         self.assertEqual(pg.escape_bytea('plain'), 'plain')
    322         self.assertEqual(pg.escape_bytea(
    323             "that's k\xe4se"), "that''s k\\\\344se")
    324         self.assertEqual(pg.escape_bytea(
    325             'O\x00ps\xff!'), r'O\\000ps\\377!')
    326 
    327     def testUnescapeBytea(self):
    328         self.assertEqual(pg.unescape_bytea('plain'), 'plain')
    329         self.assertEqual(pg.unescape_bytea(
    330             "that's k\\344se"), "that's k\xe4se")
    331         self.assertEqual(pg.unescape_bytea(
    332             r'O\000ps\377!'), 'O\x00ps\xff!')
    333 
    334 
    335 class TestCanConnect(unittest.TestCase):
    336     """Test whether a basic connection to PostgreSQL is possible."""
    337 
    338     def testCanConnectTemplate1(self):
    339         dbname = 'template1'
    340         try:
    341             connection = pg.connect(dbname)
    342         except pg.Error:
    343             self.fail('Cannot connect to database ' + dbname)
    344         try:
    345             connection.close()
    346         except pg.Error:
    347             self.fail('Cannot close the database connection')
    348 
    349 
    350 class TestConnectObject(unittest.TestCase):
    351     """"Test existence of basic pg connection methods."""
    352 
    353     def setUp(self):
    354         dbname = 'template1'
    355         self.dbname = dbname
    356         self.connection = pg.connect(dbname)
    357 
    358     def tearDown(self):
    359         self.connection.close()
    360 
    361     def testAllConnectAttributes(self):
    362         attributes = '''db error host options port
    363             protocol_version server_version status tty user'''.split()
    364         connection_attributes = [a for a in dir(self.connection)
    365             if not callable(eval("self.connection." + a))]
    366         self.assertEqual(attributes, connection_attributes)
    367 
    368     def testAllConnectMethods(self):
    369         methods = '''cancel close endcopy
    370             escape_bytea escape_identifier escape_literal escape_string
    371             fileno get_notice_receiver getline getlo getnotify
    372             inserttable locreate loimport parameter putline query reset
    373             set_notice_receiver source transaction'''.split()
    374         connection_methods = [a for a in dir(self.connection)
    375             if callable(eval("self.connection." + a))]
    376         self.assertEqual(methods, connection_methods)
    377 
    378     def testAttributeDb(self):
    379         self.assertEqual(self.connection.db, self.dbname)
    380 
    381     def testAttributeError(self):
    382         error = self.connection.error
    383         self.assertTrue(not error or 'krb5_' in error)
    384 
    385     def testAttributeHost(self):
    386         def_host = 'localhost'
    387         self.assertEqual(self.connection.host, def_host)
    388 
    389     def testAttributeOptions(self):
    390         no_options = ''
    391         self.assertEqual(self.connection.options, no_options)
    392 
    393     def testAttributePort(self):
    394         def_port = 5432
    395         self.assertEqual(self.connection.port, def_port)
    396 
    397     def testAttributeProtocolVersion(self):
    398         protocol_version = self.connection.protocol_version
    399         self.assertTrue(isinstance(protocol_version, int))
    400         self.assertTrue(2 <= protocol_version < 4)
    401 
    402     def testAttributeServerVersion(self):
    403         server_version = self.connection.server_version
    404         self.assertTrue(isinstance(server_version, int))
    405         self.assertTrue(70400 <= server_version < 100000)
    406 
    407     def testAttributeStatus(self):
    408         status_ok = 1
    409         self.assertEqual(self.connection.status, status_ok)
    410 
    411     def testAttributeTty(self):
    412         def_tty = ''
    413         self.assertEqual(self.connection.tty, def_tty)
    414 
    415     def testAttributeUser(self):
    416         no_user = 'Deprecated facility'
    417         user = self.connection.user
    418         self.assertTrue(user)
    419         self.assertNotEqual(user, no_user)
    420 
    421     def testMethodQuery(self):
    422         self.connection.query("select 1+1")
    423         self.connection.query("select 1+$1", (1,))
    424         self.connection.query("select 1+$1+$2", (2, 3))
    425         self.connection.query("select 1+$1+$2", [2, 3])
    426 
    427     def testMethodQueryEmpty(self):
    428         self.assertRaises(ValueError, self.connection.query, '')
    429 
    430     def testMethodEndcopy(self):
    431         try:
    432             self.connection.endcopy()
    433         except IOError:
    434             pass
    435 
    436     def testMethodClose(self):
    437         self.connection.close()
    438         try:
    439             self.connection.reset()
    440         except (pg.Error, TypeError):
    441             pass
    442         else:
    443             self.fail('Reset should give an error for a closed connection')
    444         self.assertRaises(pg.InternalError, self.connection.close)
    445         try:
    446             self.connection.query('select 1')
    447         except (pg.Error, TypeError):
    448             pass
    449         else:
    450             self.fail('Query should give an error for a closed connection')
    451         self.connection = pg.connect(self.dbname)
    452 
    453 
    454 class TestSimpleQueries(unittest.TestCase):
    455     """"Test simple queries via a basic pg connection."""
    456 
    457     # Test database needed: must be run as a DBTestSuite.
    458 
    459     def setUp(self):
    460         dbname = DBTestSuite.dbname
    461         self.c = pg.connect(dbname)
    462 
    463     def tearDown(self):
    464         self.c.close()
    465 
    466     def testSelect0(self):
    467         q = "select 0"
    468         self.c.query(q)
    469 
    470     def testSelect0Semicolon(self):
    471         q = "select 0;"
    472         self.c.query(q)
    473 
    474     def testSelectDotSemicolon(self):
    475         q = "select .;"
    476         self.assertRaises(pg.ProgrammingError, self.c.query, q)
    477 
    478     def testGetresult(self):
    479         q = "select 0"
    480         result = [(0,)]
    481         r = self.c.query(q).getresult()
    482         self.assertEqual(r, result)
    483 
    484     def testDictresult(self):
    485         q = "select 0 as alias0"
    486         result = [{'alias0': 0}]
    487         r = self.c.query(q).dictresult()
    488         self.assertEqual(r, result)
    489 
    490     def testNamedresult(self):
    491         if namedtuple:
    492             q = "select 0 as alias0"
    493             result = [(0,)]
    494             r = self.c.query(q).namedresult()
    495             self.assertEqual(r, result)
    496             v = r[0]
    497             self.assertEqual(v._fields, ('alias0',))
    498             self.assertEqual(v.alias0, 0)
    499 
    500     def testGet3Cols(self):
    501         q = "select 1,2,3"
    502         result = [(1, 2, 3)]
    503         r = self.c.query(q).getresult()
    504         self.assertEqual(r, result)
    505 
    506     def testGet3DictCols(self):
    507         q = "select 1 as a,2 as b,3 as c"
    508         result = [dict(a=1, b=2, c=3)]
    509         r = self.c.query(q).dictresult()
    510         self.assertEqual(r, result)
    511 
    512     def testGet3NamedCols(self):
    513         if namedtuple:
    514             q = "select 1 as a,2 as b,3 as c"
    515             result = [(1, 2, 3)]
    516             r = self.c.query(q).namedresult()
    517             self.assertEqual(r, result)
    518             v = r[0]
    519             self.assertEqual(v._fields, ('a', 'b', 'c'))
    520             self.assertEqual(v.b, 2)
    521 
    522     def testGet3Rows(self):
    523         q = "select 3 union select 1 union select 2 order by 1"
    524         result = [(1,), (2,), (3,)]
    525         r = self.c.query(q).getresult()
    526         self.assertEqual(r, result)
    527 
    528     def testGet3DictRows(self):
    529         q = ("select 3 as alias3"
    530             " union select 1 union select 2 order by 1")
    531         result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
    532         r = self.c.query(q).dictresult()
    533         self.assertEqual(r, result)
    534 
    535     def testGet3NamedRows(self):
    536         if namedtuple:
    537             q = ("select 3 as alias3"
    538                 " union select 1 union select 2 order by 1")
    539             result = [(1,), (2,), (3,)]
    540             r = self.c.query(q).namedresult()
    541             self.assertEqual(r, result)
    542             for v in r:
    543                 self.assertEqual(v._fields, ('alias3',))
    544 
    545     def testDictresultNames(self):
    546         q = "select 'MixedCase' as MixedCaseAlias"
    547         result = [{'mixedcasealias': 'MixedCase'}]
    548         r = self.c.query(q).dictresult()
    549         self.assertEqual(r, result)
    550         q = "select 'MixedCase' as \"MixedCaseAlias\""
    551         result = [{'MixedCaseAlias': 'MixedCase'}]
    552         r = self.c.query(q).dictresult()
    553         self.assertEqual(r, result)
    554 
    555     def testNamedresultNames(self):
    556         if namedtuple:
    557             q = "select 'MixedCase' as MixedCaseAlias"
    558             result = [('MixedCase',)]
    559             r = self.c.query(q).namedresult()
    560             self.assertEqual(r, result)
    561             v = r[0]
    562             self.assertEqual(v._fields, ('mixedcasealias',))
    563             self.assertEqual(v.mixedcasealias, 'MixedCase')
    564             q = "select 'MixedCase' as \"MixedCaseAlias\""
    565             r = self.c.query(q).namedresult()
    566             self.assertEqual(r, result)
    567             v = r[0]
    568             self.assertEqual(v._fields, ('MixedCaseAlias',))
    569             self.assertEqual(v.MixedCaseAlias, 'MixedCase')
    570 
    571     def testBigGetresult(self):
    572         num_cols = 100
    573         num_rows = 100
    574         q = "select " + ','.join(map(str, xrange(num_cols)))
    575         q = ' union all '.join((q,) * num_rows)
    576         r = self.c.query(q).getresult()
    577         result = [tuple(range(num_cols))] * num_rows
    578         self.assertEqual(r, result)
    579 
    580     def testListfields(self):
    581         q = ('select 0 as a, 0 as b, 0 as c,'
    582             ' 0 as c, 0 as b, 0 as a,'
    583             ' 0 as lowercase, 0 as UPPERCASE,'
    584             ' 0 as MixedCase, 0 as "MixedCase",'
    585             ' 0 as a_long_name_with_underscores,'
    586             ' 0 as "A long name with Blanks"')
    587         r = self.c.query(q).listfields()
    588         result = ('a', 'b', 'c', 'c', 'b', 'a',
    589             'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
    590             'a_long_name_with_underscores',
    591             'A long name with Blanks')
    592         self.assertEqual(r, result)
    593 
    594     def testFieldname(self):
    595         q = "select 0 as z, 0 as a, 0 as x, 0 as y"
    596         r = self.c.query(q).fieldname(2)
    597         result = "x"
    598         self.assertEqual(r, result)
    599 
    600     def testFieldnum(self):
    601         q = "select 0 as z, 0 as a, 0 as x, 0 as y"
    602         r = self.c.query(q).fieldnum("x")
    603         result = 2
    604         self.assertEqual(r, result)
    605 
    606     def testNtuples(self):
    607         q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
    608             " union select 5 as a, 6 as b, 7 as c, 8 as d")
    609         r = self.c.query(q).ntuples()
    610         result = 2
    611         self.assertEqual(r, result)
    612 
    613     def testQuery(self):
    614         smart_ddl(self.c, "drop table test_table")
    615         q = "create table test_table (n integer) with oids"
    616         r = self.c.query(q)
    617         self.assertTrue(r is None)
    618         q = "insert into test_table values (1)"
    619         r = self.c.query(q)
    620         self.assertTrue(isinstance(r, int)), r
    621         q = "insert into test_table select 2"
    622         r = self.c.query(q)
    623         self.assertTrue(isinstance(r, int))
    624         oid = r
    625         q = "select oid from test_table where n=2"
    626         r = self.c.query(q).getresult()
    627         self.assertEqual(len(r), 1)
    628         r = r[0]
    629         self.assertEqual(len(r), 1)
    630         r = r[0]
    631         self.assertEqual(r, oid)
    632         q = "insert into test_table select 3 union select 4 union select 5"
    633         r = self.c.query(q)
    634         self.assertTrue(isinstance(r, str))
    635         self.assertEqual(r, '3')
    636         q = "update test_table set n=4 where n<5"
    637         r = self.c.query(q)
    638         self.assertTrue(isinstance(r, str))
    639         self.assertEqual(r, '4')
    640         q = "delete from test_table"
    641         r = self.c.query(q)
    642         self.assertTrue(isinstance(r, str))
    643         self.assertEqual(r, '5')
    644 
    645     def testPrint(self):
    646         import os
    647         q = ("select 1 as a, 'hello' as h, 'w' as world"
    648             " union select 2, 'xyz', 'uvw'")
    649         r = self.c.query(q)
    650         t = '~test_pg_testPrint_temp.tmp'
    651         s = open(t, 'w')
    652         stdout, sys.stdout = sys.stdout, s
    653         try:
    654             print r
    655         except Exception:
    656             pass
    657         sys.stdout = stdout
    658         s.close()
    659         r = filter(bool, open(t, 'r').read().splitlines())
    660         os.remove(t)
    661         self.assertEqual(r,
    662             ['a|  h  |world',
    663              '-+-----+-----',
    664              '1|hello|w    ',
    665              '2|xyz  |uvw  ',
    666              '(2 rows)'])
    667 
    668     def testGetNotify(self):
    669         self.assertTrue(self.c.getnotify() is None)
    670         self.c.query('listen test_notify')
    671         try:
    672             self.assertTrue(self.c.getnotify() is None)
    673             self.c.query("notify test_notify")
    674             r = self.c.getnotify()
    675             self.assertTrue(isinstance(r, tuple))
    676             self.assertEqual(len(r), 3)
    677             self.assertTrue(isinstance(r[0], str))
    678             self.assertTrue(isinstance(r[1], int))
    679             self.assertTrue(isinstance(r[2], str))
    680             self.assertEqual(r[0], 'test_notify')
    681             self.assertEqual(r[2], '')
    682             self.assertTrue(self.c.getnotify() is None)
    683             try:
    684                 self.c.query("notify test_notify, 'test_payload'")
    685             except pg.ProgrammingError:  # PostgreSQL < 9.0
    686                 pass
    687             else:
    688                 r = self.c.getnotify()
    689                 self.assertTrue(isinstance(r, tuple))
    690                 self.assertEqual(len(r), 3)
    691                 self.assertTrue(isinstance(r[0], str))
    692                 self.assertTrue(isinstance(r[1], int))
    693                 self.assertTrue(isinstance(r[2], str))
    694                 self.assertEqual(r[0], 'test_notify')
    695                 self.assertEqual(r[2], 'test_payload')
    696                 self.assertTrue(self.c.getnotify() is None)
    697         finally:
    698             self.c.query('unlisten test_notify')
    699 
    700 
    701 class TestParamQueries(unittest.TestCase):
    702     """"Test queries with parameters via a basic pg connection."""
    703 
    704     # Test database needed: must be run as a DBTestSuite.
    705 
    706     def setUp(self):
    707         dbname = DBTestSuite.dbname
    708         self.c = pg.connect(dbname)
    709 
    710     def tearDown(self):
    711         self.c.query("set client_encoding to UTF8")
    712         self.c.close()
    713 
    714     def testQueryWithNoneParam(self):
    715         self.assertEqual(self.c.query("select $1::integer", (None,)
    716             ).getresult(), [(None,)])
    717         self.assertEqual(self.c.query("select $1::text", [None]
    718             ).getresult(), [(None,)])
    719 
    720     def testQueryWithBoolParams(self):
    721         query = self.c.query
    722         self.assertEqual(query("select false").getresult(), [('f',)])
    723         self.assertEqual(query("select true").getresult(), [('t',)])
    724         self.assertEqual(query("select $1::bool", (None,)).getresult(),
    725             [(None,)])
    726         self.assertEqual(query("select $1::bool", ('f',)).getresult(), [('f',)])
    727         self.assertEqual(query("select $1::bool", ('t',)).getresult(), [('t',)])
    728         self.assertEqual(query("select $1::bool", ('false',)).getresult(),
    729             [('f',)])
    730         self.assertEqual(query("select $1::bool", ('true',)).getresult(),
    731             [('t',)])
    732         self.assertEqual(query("select $1::bool", ('n',)).getresult(), [('f',)])
    733         self.assertEqual(query("select $1::bool", ('y',)).getresult(), [('t',)])
    734         self.assertEqual(query("select $1::bool", (0,)).getresult(), [('f',)])
    735         self.assertEqual(query("select $1::bool", (1,)).getresult(), [('t',)])
    736         self.assertEqual(query("select $1::bool", (False,)).getresult(),
    737             [('f',)])
    738         self.assertEqual(query("select $1::bool", (True,)).getresult(),
    739             [('t',)])
    740 
    741     def testQueryWithIntParams(self):
    742         query = self.c.query
    743         self.assertEqual(query("select 1+1").getresult(), [(2,)])
    744         self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
    745         self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
    746         self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
    747         self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
    748         self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
    749             [(Decimal('2'),)])
    750         self.assertEqual(query("select 1, $1::integer", (2,)
    751             ).getresult(), [(1, 2)])
    752         self.assertEqual(query("select 1 union select $1", (2,)
    753             ).getresult(), [(1,), (2,)])
    754         self.assertEqual(query("select $1::integer+$2", (1, 2)
    755             ).getresult(), [(3,)])
    756         self.assertEqual(query("select $1::integer+$2", [1, 2]
    757             ).getresult(), [(3,)])
    758         self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", range(6)
    759             ).getresult(), [(15,)])
    760 
    761     def testQueryWithStrParams(self):
    762         query = self.c.query
    763         self.assertEqual(query("select $1||', world!'", ('Hello',)
    764             ).getresult(), [('Hello, world!',)])
    765         self.assertEqual(query("select $1||', world!'", ['Hello']
    766             ).getresult(), [('Hello, world!',)])
    767         self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
    768             ).getresult(), [('Hello, world!',)])
    769         self.assertEqual(query("select $1::text", ('Hello, world!',)
    770             ).getresult(), [('Hello, world!',)])
    771         self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
    772             ).getresult(), [('Hello', 'world')])
    773         self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
    774             ).getresult(), [('Hello', 'world')])
    775         self.assertEqual(query("select $1::text union select $2::text",
    776             ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
    777         self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
    778             'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
    779 
    780     def testQueryWithUnicodeParams(self):
    781         query = self.c.query
    782         self.assertEqual(query("select $1||', '||$2||'!'",
    783             ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
    784         self.assertEqual(query("select $1||', '||$2||'!'",
    785             ('Hello', u'\u043c\u0438\u0440')).getresult(),
    786             [('Hello, \xd0\xbc\xd0\xb8\xd1\x80!',)])
    787         query('set client_encoding = latin1')
    788         self.assertEqual(query("select $1||', '||$2||'!'",
    789             ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
    790         self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
    791             ('Hello', u'\u043c\u0438\u0440'))
    792         query('set client_encoding = iso_8859_1')
    793         self.assertEqual(query("select $1||', '||$2||'!'",
    794             ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
    795         self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
    796             ('Hello', u'\u043c\u0438\u0440'))
    797         query('set client_encoding = iso_8859_5')
    798         self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
    799             ('Hello', u'w\xf6rld'))
    800         self.assertEqual(query("select $1||', '||$2||'!'",
    801             ('Hello', u'\u043c\u0438\u0440')).getresult(),
    802             [('Hello, \xdc\xd8\xe0!',)])
    803         query('set client_encoding = sql_ascii')
    804         self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
    805             ('Hello', u'w\xf6rld'))
    806 
    807     def testQueryWithMixedParams(self):
    808         self.assertEqual(self.c.query("select $1+2,$2||', world!'",
    809             (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
    810         self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
    811             (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
    812 
    813     def testQueryWithDuplicateParams(self):
    814         self.assertRaises(pg.ProgrammingError,
    815             self.c.query, "select $1+$1", (1,))
    816         self.assertRaises(pg.ProgrammingError,
    817             self.c.query, "select $1+$1", (1, 2))
    818 
    819     def testQueryWithZeroParams(self):
    820         self.assertEqual(self.c.query("select 1+1", []
    821             ).getresult(), [(2,)])
    822 
    823     def testQueryWithGarbage(self):
    824         garbage = r"'\{}+()-#[]oo324"
    825         self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
    826             ).dictresult(), [{'garbage': garbage}])
    827 
    828     def testUnicodeQuery(self):
    829         query = self.c.query
    830         self.assertEqual(query(u"select 1+1").getresult(), [(2,)])
    831         self.assertRaises(TypeError, query, u"select 'Hello, w\xf6rld!'")
    832 
    833 
    834 class TestInserttable(unittest.TestCase):
    835     """"Test inserttable method."""
    836 
    837     # Test database needed: must be run as a DBTestSuite.
    838 
    839     def setUp(self):
    840         dbname = DBTestSuite.dbname
    841         self.c = pg.connect(dbname)
    842         self.c.query('truncate table test')
    843 
    844     def tearDown(self):
    845         self.c.close()
    846 
    847     def testInserttable1Row(self):
    848         d = Decimal is float and 1.0 or None
    849         data = [(1, 1, 1L, d, 1.0, 1.0, d, "1", "1111", "1")]
    850         self.c.inserttable("test", data)
    851         r = self.c.query("select * from test").getresult()
    852         self.assertEqual(r, data)
    853 
    854     def testInserttable4Rows(self):
    855         data = [(-1, -1, -1L, None, -1.0, -1.0, None, "-1", "-1-1", "-1"),
    856             (0, 0, 0L, None, 0.0, 0.0, None, "0", "0000", "0"),
    857             (1, 1, 1L, None, 1.0, 1.0, None, "1", "1111", "1"),
    858             (2, 2, 2L, None, 2.0, 2.0, None, "2", "2222", "2")]
    859         self.c.inserttable("test", data)
    860         r = self.c.query("select * from test order by 1").getresult()
    861         self.assertEqual(r, data)
    862 
    863     def testInserttableMultipleRows(self):
    864         num_rows = 100
    865         data = [(1, 1, 1L, None, 1.0, 1.0, None, "1", "1111", "1")] * num_rows
    866         self.c.inserttable("test", data)
    867         r = self.c.query("select count(*) from test").getresult()[0][0]
    868         self.assertEqual(r, num_rows)
    869 
    870     def testInserttableMultipleCalls(self):
    871         num_rows = 10
    872         data = [(1, 1, 1L, None, 1.0, 1.0, None, "1", "1111", "1")]
    873         for _i in range(num_rows):
    874             self.c.inserttable("test", data)
    875         r = self.c.query("select count(*) from test").getresult()[0][0]
    876         self.assertEqual(r, num_rows)
    877 
    878     def testInserttableNullValues(self):
    879         data = [(None,) * 10] * 10
    880         self.c.inserttable("test", data)
    881         r = self.c.query("select * from test").getresult()
    882         self.assertEqual(r, data)
    883 
    884     def testInserttableMaxValues(self):
    885         data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
    886             None, 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
    887             "1234", "1234", "1234" * 10)]
    888         self.c.inserttable("test", data)
    889         r = self.c.query("select * from test").getresult()
    890         self.assertEqual(r, data)
    891 
    892 
    893 class TestNoticeReceiver(unittest.TestCase):
    894     """"Test notice receiver support."""
    895 
    896     # Test database needed: must be run as a DBTestSuite.
    897 
    898     def setUp(self):
    899         self.dbname = DBTestSuite.dbname
    900 
    901     def testGetNoticeReceiver(self):
    902         c = pg.connect(self.dbname)
    903         try:
    904             self.assertTrue(c.get_notice_receiver() is None)
    905         finally:
    906             c.close()
    907 
    908     def testSetNoticeReceiver(self):
    909         c = pg.connect(self.dbname)
    910         try:
    911             self.assertRaises(TypeError, c.set_notice_receiver, None)
    912             self.assertRaises(TypeError, c.set_notice_receiver, 42)
    913             self.assertTrue(c.set_notice_receiver(lambda notice: None) is None)
    914         finally:
    915             c.close()
    916 
    917     def testSetandGetNoticeReceiver(self):
    918         c = pg.connect(self.dbname)
    919         try:
    920             r = lambda notice: None
    921             self.assertTrue(c.set_notice_receiver(r) is None)
    922             self.assertTrue(c.get_notice_receiver() is r)
    923         finally:
    924             c.close()
    925 
    926     def testNoticeReceiver(self):
    927         c = pg.connect(self.dbname)
    928         try:
    929             c.query('''create function bilbo_notice() returns void AS $$
    930                 begin
    931                     raise warning 'Bilbo was here!';
    932                 end;
    933                 $$ language plpgsql''')
    934             try:
    935                 received = {}
    936 
    937                 def notice_receiver(notice):
    938                     for attr in dir(notice):
    939                         value = getattr(notice, attr)
    940                         if isinstance(value, str):
    941                             value = value.replace('WARNUNG', 'WARNING')
    942                         received[attr] = value
    943 
    944                 c.set_notice_receiver(notice_receiver)
    945                 c.query('''select bilbo_notice()''')
    946                 self.assertEqual(received, dict(
    947                     pgcnx=c, message='WARNING:  Bilbo was here!\n',
    948                     severity='WARNING', primary='Bilbo was here!',
    949                     detail=None, hint=None))
    950             finally:
    951                 c.query('''drop function bilbo_notice();''')
    952         finally:
    953             c.close()
     35    from LOCAL_PyGreSQL import *
     36except ImportError:
     37    pass
     38
     39
     40def DB():
     41    """Create a DB wrapper object connecting to the test database."""
     42    db = pg.DB(dbname, dbhost, dbport)
     43    if debug:
     44        db.debug = debug
     45    db.query("set client_min_messages=warning")
     46    return db
    95447
    95548
     
    95851
    95952    def setUp(self):
    960         dbname = 'template1'
    961         self.dbname = dbname
    962         self.db = pg.DB(dbname)
     53        self.db = DB()
    96354
    96455    def tearDown(self):
    965         self.db.close()
     56        try:
     57            self.db.close()
     58        except pg.InternalError:
     59            pass
    96660
    96761    def testAllDBAttributes(self):
     
    1029123
    1030124    def testAttributeDb(self):
    1031         self.assertEqual(self.db.db.db, self.dbname)
     125        self.assertEqual(self.db.db.db, dbname)
    1032126
    1033127    def testAttributeDbname(self):
    1034         self.assertEqual(self.db.dbname, self.dbname)
     128        self.assertEqual(self.db.dbname, dbname)
    1035129
    1036130    def testAttributeError(self):
     
    1089183
    1090184    def testMethodEscapeLiteral(self):
    1091         self.assertEqual(self.db.escape_literal("plain"), "'plain'")
    1092         self.assertEqual(self.db.escape_literal(
    1093             "that's k\xe4se"), "'that''s k\xe4se'")
    1094         self.assertEqual(self.db.escape_literal(
    1095             r"It's fine to have a \ inside."),
    1096             r" E'It''s fine to have a \\ inside.'")
    1097         self.assertEqual(self.db.escape_literal(
    1098             'No "quotes" must be escaped.'),
    1099             "'No \"quotes\" must be escaped.'")
     185        self.assertEqual(self.db.escape_literal(''), "''")
    1100186
    1101187    def testMethodEscapeIdentifier(self):
    1102         self.assertEqual(self.db.escape_identifier("plain"), '"plain"')
    1103         self.assertEqual(self.db.escape_identifier(
    1104             "that's k\xe4se"), '"that\'s k\xe4se"')
    1105         self.assertEqual(self.db.escape_identifier(
    1106             r"It's fine to have a \ inside."),
    1107             '"It\'s fine to have a \\ inside."')
    1108         self.assertEqual(self.db.escape_identifier(
    1109             'All "quotes" must be escaped.'),
    1110             '"All ""quotes"" must be escaped."')
     188        self.assertEqual(self.db.escape_identifier(''), '""')
    1111189
    1112190    def testMethodEscapeString(self):
    1113         standard_conforming = self.db.query(
    1114             "show standard_conforming_strings").getresult()[0][0]
    1115         self.assertTrue(standard_conforming in ('on', 'off'))
    1116         self.assertEqual(self.db.escape_string("plain"), "plain")
    1117         self.assertEqual(self.db.escape_string(
    1118             "that's k\xe4se"), "that''s k\xe4se")
    1119         if standard_conforming == 'on':
    1120             self.assertEqual(self.db.escape_string(
    1121                 r"It's fine to have a \ inside."),
    1122                 r"It''s fine to have a \ inside.")
    1123         else:
    1124             self.assertEqual(self.db.escape_string(
    1125                 r"It's fine to have a \ inside."),
    1126                 r"It''s fine to have a \\ inside.")
     191        self.assertEqual(self.db.escape_string(''), '')
    1127192
    1128193    def testMethodEscapeBytea(self):
    1129         output = self.db.query("show bytea_output").getresult()[0][0]
    1130         if output == 'escape' and self.db.escape_bytea("plain") != 'plain':
    1131             output = 'hex'
    1132         self.assertTrue(output in ('escape', 'hex'))
    1133         standard_conforming = self.db.query(
    1134             "show standard_conforming_strings").getresult()[0][0]
    1135         self.assertTrue(standard_conforming in ('on', 'off'))
    1136         if output == 'escape':
    1137             self.assertEqual(self.db.escape_bytea("plain"), "plain")
    1138             self.assertEqual(self.db.escape_bytea(
    1139                 "that's k\xe4se"), "that''s k\\\\344se")
    1140             self.assertEqual(self.db.escape_bytea(
    1141                 'O\x00ps\xff!'), r'O\\000ps\\377!')
    1142         elif standard_conforming == 'on':
    1143             self.assertEqual(self.db.escape_bytea("plain"), r"\x706c61696e")
    1144             self.assertEqual(self.db.escape_bytea(
    1145                 "that's k\xe4se"), r"\x746861742773206be47365")
    1146             self.assertEqual(self.db.escape_bytea(
    1147                 'O\x00ps\xff!'), r"\x4f007073ff21")
    1148         else:
    1149             self.assertEqual(self.db.escape_bytea("plain"), r"\\x706c61696e")
    1150             self.assertEqual(self.db.escape_bytea(
    1151                 "that's k\xe4se"), r"\\x746861742773206be47365")
    1152             self.assertEqual(self.db.escape_bytea(
    1153                 'O\x00ps\xff!'), r"\\x4f007073ff21")
     194        self.assertEqual(self.db.escape_bytea('').replace(
     195            '\\x', '').replace('\\', ''), '')
    1154196
    1155197    def testMethodUnescapeBytea(self):
    1156         standard_conforming = self.db.query(
    1157             "show standard_conforming_strings").getresult()[0][0]
    1158         self.assertTrue(standard_conforming in ('on', 'off'))
    1159         self.assertEqual(self.db.unescape_bytea("plain"), "plain")
    1160         self.assertEqual(self.db.unescape_bytea(
    1161             "that's k\\344se"), "that's k\xe4se")
    1162         self.assertEqual(pg.unescape_bytea(
    1163             r'O\000ps\377!'), 'O\x00ps\xff!')
    1164         self.assertEqual(self.db.unescape_bytea(r"\x706c61696e"), "plain")
    1165         self.assertEqual(self.db.unescape_bytea(
    1166             r"\x746861742773206be47365"), "that's k\xe4se")
    1167         self.assertEqual(pg.unescape_bytea(
    1168             r"\x4f007073ff21"), 'O\x00ps\xff!')
     198        self.assertEqual(self.db.unescape_bytea(''), '')
    1169199
    1170200    def testMethodQuery(self):
    1171         self.db.query("select 1+1")
    1172         self.db.query("select 1+$1", 1)
    1173         self.db.query("select 1+$1+$2", 2, 3)
    1174         self.db.query("select 1+$1+$2", (2, 3))
    1175         self.db.query("select 1+$1+$2", [2, 3])
     201        query = self.db.query
     202        query("select 1+1")
     203        query("select 1+$1+$2", 2, 3)
     204        query("select 1+$1+$2", (2, 3))
     205        query("select 1+$1+$2", [2, 3])
     206        query("select 1+$1", 1)
    1176207
    1177208    def testMethodQueryEmpty(self):
     
    1200231        self.assertRaises(pg.InternalError, self.db.close)
    1201232        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
    1202         self.db = pg.DB(self.dbname)
    1203233
    1204234    def testExistingConnection(self):
     
    1229259    """"Test the methods of the DB class wrapped pg connection."""
    1230260
    1231     # Test database needed: must be run as a DBTestSuite.
     261    @classmethod
     262    def setUpClass(cls):
     263        db = DB()
     264        db.query("drop table if exists test cascade")
     265        db.query("create table test ("
     266            "i2 smallint, i4 integer, i8 bigint,"
     267            "d numeric, f4 real, f8 double precision, m money, "
     268            "v4 varchar(4), c4 char(4), t text)")
     269        db.query("create or replace view test_view as"
     270            " select i4, v4 from test")
     271        db.close()
     272
     273    @classmethod
     274    def tearDownClass(cls):
     275        db = DB()
     276        db.query("drop table test cascade")
     277        db.close()
    1232278
    1233279    def setUp(self):
    1234         dbname = DBTestSuite.dbname
    1235         self.dbname = dbname
    1236         self.db = pg.DB(dbname)
    1237         self.db.query("set lc_monetary='C'");
    1238         if debug:
    1239             self.db.debug = 'DEBUG: %s'
     280        self.db = DB()
     281        self.db.query("set lc_monetary='C'")
     282        self.db.query('set bytea_output=hex')
     283        self.db.query('set standard_conforming_strings=on')
    1240284
    1241285    def tearDown(self):
    1242286        self.db.close()
    1243287
     288    def testEscapeLiteral(self):
     289        f = self.db.escape_literal
     290        self.assertEqual(f("plain"), "'plain'")
     291        self.assertEqual(f("that's k\xe4se"), "'that''s k\xe4se'")
     292        self.assertEqual(f(r"It's fine to have a \ inside."),
     293            r" E'It''s fine to have a \\ inside.'")
     294        self.assertEqual(f('No "quotes" must be escaped.'),
     295            "'No \"quotes\" must be escaped.'")
     296
     297    def testEscapeIdentifier(self):
     298        f = self.db.escape_identifier
     299        self.assertEqual(f("plain"), '"plain"')
     300        self.assertEqual(f("that's k\xe4se"), '"that\'s k\xe4se"')
     301        self.assertEqual(f(r"It's fine to have a \ inside."),
     302            '"It\'s fine to have a \\ inside."')
     303        self.assertEqual(f('All "quotes" must be escaped.'),
     304            '"All ""quotes"" must be escaped."')
     305
    1244306    def testEscapeString(self):
    1245         self.assertEqual(self.db.escape_string("plain"), "plain")
    1246         self.assertEqual(self.db.escape_string(
    1247             "that's k\xe4se"), "that''s k\xe4se")
    1248         self.assertEqual(self.db.escape_string(
    1249             r"It's fine to have a \ inside."),
    1250             r"It''s fine to have a \\ inside.")
     307        f = self.db.escape_string
     308        self.assertEqual(f("plain"), "plain")
     309        self.assertEqual(f("that's k\xe4se"), "that''s k\xe4se")
     310        self.assertEqual(f(r"It's fine to have a \ inside."),
     311            r"It''s fine to have a \ inside.")
    1251312
    1252313    def testEscapeBytea(self):
    1253         output = self.db.query("show bytea_output").getresult()[0][0]
    1254         if output == 'escape' and self.db.escape_bytea("plain") != 'plain':
    1255             output = 'hex'
    1256         self.assertTrue(output in ('escape', 'hex'))
    1257         standard_conforming = self.db.query(
    1258             "show standard_conforming_strings").getresult()[0][0]
    1259         self.assertTrue(standard_conforming in ('on', 'off'))
    1260         if output == 'escape':
    1261             self.assertEqual(self.db.escape_bytea("plain"), "plain")
    1262             self.assertEqual(self.db.escape_bytea(
    1263                 "that's k\xe4se"), "that''s k\\\\344se")
    1264             self.assertEqual(self.db.escape_bytea(
    1265                 'O\x00ps\xff!'), r'O\\000ps\\377!')
    1266         elif standard_conforming == 'on':
    1267             self.assertEqual(self.db.escape_bytea("plain"), r"\x706c61696e")
    1268             self.assertEqual(self.db.escape_bytea(
    1269                 "that's k\xe4se"), r"\x746861742773206be47365")
    1270             self.assertEqual(self.db.escape_bytea(
    1271                 'O\x00ps\xff!'), r"\x4f007073ff21")
    1272         else:
    1273             self.assertEqual(self.db.escape_bytea("plain"), r"\\x706c61696e")
    1274             self.assertEqual(self.db.escape_bytea(
    1275                 "that's k\xe4se"), r"\\x746861742773206be47365")
    1276             self.assertEqual(self.db.escape_bytea(
    1277                 'O\x00ps\xff!'), r"\\x4f007073ff21")
     314        f = self.db.escape_bytea
     315        # note that escape_byte always returns hex output since Pg 9.0,
     316        # regardless of the bytea_output setting
     317        self.assertEqual(f("plain"), r"\x706c61696e")
     318        self.assertEqual(f("that's k\xe4se"), r"\x746861742773206be47365")
     319        self.assertEqual(f('O\x00ps\xff!'), r"\x4f007073ff21")
    1278320
    1279321    def testUnescapeBytea(self):
    1280         standard_conforming = self.db.query(
    1281             "show standard_conforming_strings").getresult()[0][0]
    1282         self.assertTrue(standard_conforming in ('on', 'off'))
    1283         self.assertEqual(self.db.unescape_bytea("plain"), "plain")
    1284         self.assertEqual(self.db.unescape_bytea(
    1285             "that's k\\344se"), "that's k\xe4se")
    1286         self.assertEqual(pg.unescape_bytea(
    1287             r'O\000ps\377!'), 'O\x00ps\xff!')
    1288         if standard_conforming == 'on':
    1289             self.assertEqual(self.db.unescape_bytea(r"\\x706c61696e"), "plain")
    1290             self.assertEqual(self.db.unescape_bytea(
    1291                 r"\\x746861742773206be47365"), "that's k\xe4se")
    1292             self.assertEqual(pg.unescape_bytea(
    1293                 r"\\x4f007073ff21"), 'O\x00ps\xff!')
    1294         else:
    1295             self.assertEqual(self.db.unescape_bytea(r"\x706c61696e"), "plain")
    1296             self.assertEqual(self.db.unescape_bytea(
    1297                 r"\x746861742773206be47365"), "that's k\xe4se")
    1298             self.assertEqual(pg.unescape_bytea(
    1299                 r"\x4f007073ff21"), 'O\x00ps\xff!')
     322        f = self.db.unescape_bytea
     323        self.assertEqual(f("plain"), "plain")
     324        self.assertEqual(f("that's k\\344se"), "that's k\xe4se")
     325        self.assertEqual(f(r'O\000ps\377!'), 'O\x00ps\xff!')
     326        self.assertEqual(f(r"\\x706c61696e"), r"\x706c61696e")
     327        self.assertEqual(f(r"\\x746861742773206be47365"),
     328            r"\x746861742773206be47365")
     329        self.assertEqual(f(r"\\x4f007073ff21"), r"\x4f007073ff21")
    1300330
    1301331    def testQuote(self):
     
    1367397        self.assertEqual(f('abc', 'text'), "'abc'")
    1368398        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
     399        self.assertEqual(f('ab\\c', 'text'), "'ab\\c'")
     400        self.assertEqual(f("a\\b'c", 'text'), "'a\\b''c'")
     401        self.db.query('set standard_conforming_strings=off')
    1369402        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
    1370403        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
    1371404
    1372405    def testQuery(self):
    1373         smart_ddl(self.db, "drop table test_table")
     406        query = self.db.query
     407        query("drop table if exists test_table")
    1374408        q = "create table test_table (n integer) with oids"
    1375         r = self.db.query(q)
    1376         self.assertTrue(r is None)
     409        r = query(q)
     410        self.assertIsNone(r)
    1377411        q = "insert into test_table values (1)"
    1378         r = self.db.query(q)
    1379         self.assertTrue(isinstance(r, int))
     412        r = query(q)
     413        self.assertIsInstance(r, int)
    1380414        q = "insert into test_table select 2"
    1381         r = self.db.query(q)
    1382         self.assertTrue(isinstance(r, int))
     415        r = query(q)
     416        self.assertIsInstance(r, int)
    1383417        oid = r
    1384418        q = "select oid from test_table where n=2"
    1385         r = self.db.query(q).getresult()
     419        r = query(q).getresult()
    1386420        self.assertEqual(len(r), 1)
    1387421        r = r[0]
     
    1390424        self.assertEqual(r, oid)
    1391425        q = "insert into test_table select 3 union select 4 union select 5"
    1392         r = self.db.query(q)
    1393         self.assertTrue(isinstance(r, str))
     426        r = query(q)
     427        self.assertIsInstance(r, str)
    1394428        self.assertEqual(r, '3')
    1395429        q = "update test_table set n=4 where n<5"
    1396         r = self.db.query(q)
    1397         self.assertTrue(isinstance(r, str))
     430        r = query(q)
     431        self.assertIsInstance(r, str)
    1398432        self.assertEqual(r, '4')
    1399433        q = "delete from test_table"
    1400         r = self.db.query(q)
    1401         self.assertTrue(isinstance(r, str))
     434        r = query(q)
     435        self.assertIsInstance(r, str)
    1402436        self.assertEqual(r, '5')
     437        query("drop table test_table")
    1403438
    1404439    def testMultipleQueries(self):
     
    1409444
    1410445    def testQueryWithParams(self):
    1411         smart_ddl(self.db, "drop table test_table")
     446        query = self.db.query
     447        query("drop table if exists test_table")
    1412448        q = "create table test_table (n1 integer, n2 integer) with oids"
    1413         self.db.query(q)
     449        query(q)
    1414450        q = "insert into test_table values ($1, $2)"
    1415         r = self.db.query(q, (1, 2))
    1416         self.assertTrue(isinstance(r, int))
    1417         r = self.db.query(q, [3, 4])
    1418         self.assertTrue(isinstance(r, int))
    1419         r = self.db.query(q, [5, 6])
    1420         self.assertTrue(isinstance(r, int))
     451        r = query(q, (1, 2))
     452        self.assertIsInstance(r, int)
     453        r = query(q, [3, 4])
     454        self.assertIsInstance(r, int)
     455        r = query(q, [5, 6])
     456        self.assertIsInstance(r, int)
    1421457        q = "select * from test_table order by 1, 2"
    1422         self.assertEqual(self.db.query(q).getresult(),
     458        self.assertEqual(query(q).getresult(),
    1423459            [(1, 2), (3, 4), (5, 6)])
    1424460        q = "select * from test_table where n1=$1 and n2=$2"
    1425         self.assertEqual(self.db.query(q, 3, 4).getresult(), [(3, 4)])
     461        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
    1426462        q = "update test_table set n2=$2 where n1=$1"
    1427         r = self.db.query(q, 3, 7)
     463        r = query(q, 3, 7)
    1428464        self.assertEqual(r, '1')
    1429465        q = "select * from test_table order by 1, 2"
    1430         self.assertEqual(self.db.query(q).getresult(),
     466        self.assertEqual(query(q).getresult(),
    1431467            [(1, 2), (3, 7), (5, 6)])
    1432468        q = "delete from test_table where n2!=$1"
    1433         r = self.db.query(q, 4)
     469        r = query(q, 4)
    1434470        self.assertEqual(r, '3')
     471        query("drop table test_table")
    1435472
    1436473    def testEmptyQuery(self):
     
    1444481
    1445482    def testPkey(self):
    1446         smart_ddl(self.db, "drop table pkeytest0")
    1447         smart_ddl(self.db, "create table pkeytest0 ("
     483        query = self.db.query
     484        for n in range(4):
     485            query("drop table if exists pkeytest%d" % n)
     486        query("create table pkeytest0 ("
    1448487            "a smallint)")
    1449         smart_ddl(self.db, 'drop table pkeytest1')
    1450         smart_ddl(self.db, "create table pkeytest1 ("
     488        query("create table pkeytest1 ("
    1451489            "b smallint primary key)")
    1452         smart_ddl(self.db, 'drop table pkeytest2')
    1453         smart_ddl(self.db, "create table pkeytest2 ("
     490        query("create table pkeytest2 ("
    1454491            "c smallint, d smallint primary key)")
    1455         smart_ddl(self.db, "drop table pkeytest3")
    1456         smart_ddl(self.db, "create table pkeytest3 ("
     492        query("create table pkeytest3 ("
    1457493            "e smallint, f smallint, g smallint, "
    1458494            "h smallint, i smallint, "
    1459495            "primary key (f,h))")
    1460         self.assertRaises(KeyError, self.db.pkey, 'pkeytest0')
    1461         self.assertEqual(self.db.pkey('pkeytest1'), 'b')
    1462         self.assertEqual(self.db.pkey('pkeytest2'), 'd')
    1463         self.assertEqual(self.db.pkey('pkeytest3'), frozenset('fh'))
    1464         self.assertEqual(self.db.pkey('pkeytest0', 'none'), 'none')
    1465         self.assertEqual(self.db.pkey('pkeytest0'), 'none')
    1466         self.db.pkey(None, {'t': 'a', 'n.t': 'b'})
    1467         self.assertEqual(self.db.pkey('t'), 'a')
    1468         self.assertEqual(self.db.pkey('n.t'), 'b')
    1469         self.assertRaises(KeyError, self.db.pkey, 'pkeytest0')
     496        pkey = self.db.pkey
     497        self.assertRaises(KeyError, pkey, 'pkeytest0')
     498        self.assertEqual(pkey('pkeytest1'), 'b')
     499        self.assertEqual(pkey('pkeytest2'), 'd')
     500        self.assertEqual(pkey('pkeytest3'), frozenset('fh'))
     501        self.assertEqual(pkey('pkeytest0', 'none'), 'none')
     502        self.assertEqual(pkey('pkeytest0'), 'none')
     503        pkey(None, {'t': 'a', 'n.t': 'b'})
     504        self.assertEqual(pkey('t'), 'a')
     505        self.assertEqual(pkey('n.t'), 'b')
     506        self.assertRaises(KeyError, pkey, 'pkeytest0')
     507        for n in range(4):
     508            query("drop table pkeytest%d" % n)
    1470509
    1471510    def testGetDatabases(self):
    1472511        databases = self.db.get_databases()
    1473         self.assertTrue('template0' in databases)
    1474         self.assertTrue('template1' in databases)
    1475         self.assertTrue(self.dbname in databases)
     512        self.assertIn('template0', databases)
     513        self.assertIn('template1', databases)
     514        self.assertNotIn('not existing database', databases)
     515        self.assertIn('postgres', databases)
     516        self.assertIn(dbname, databases)
    1476517
    1477518    def testGetTables(self):
    1478         result1 = self.db.get_tables()
     519        get_tables = self.db.get_tables
     520        result1 = get_tables()
    1479521        tables = ('"A very Special Name"',
    1480522            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
     
    1483525            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
    1484526        for t in tables:
    1485             smart_ddl(self.db, 'drop table ' + t)
    1486             smart_ddl(self.db, "create table %s"
     527            self.db.query('drop table if exists %s' % t)
     528            self.db.query("create table %s"
    1487529                " as select 0" % t)
    1488         result3 = self.db.get_tables()
     530        result3 = get_tables()
    1489531        result2 = []
    1490532        for t in result3:
     
    1498540        self.assertEqual(result2, result3)
    1499541        for t in result2:
    1500             self.db.query('drop table ' + t)
    1501         result2 = self.db.get_tables()
     542            self.db.query('drop table %s' % t)
     543        result2 = get_tables()
    1502544        self.assertEqual(result2, result1)
    1503545
    1504546    def testGetRelations(self):
    1505         result = self.db.get_relations()
    1506         self.assertTrue('public.test' in result)
    1507         self.assertTrue('public.test_view' in result)
    1508         result = self.db.get_relations('rv')
    1509         self.assertTrue('public.test' in result)
    1510         self.assertTrue('public.test_view' in result)
    1511         result = self.db.get_relations('r')
    1512         self.assertTrue('public.test' in result)
    1513         self.assertTrue('public.test_view' not in result)
    1514         result = self.db.get_relations('v')
    1515         self.assertTrue('public.test' not in result)
    1516         self.assertTrue('public.test_view' in result)
    1517         result = self.db.get_relations('cisSt')
    1518         self.assertTrue('public.test' not in result)
    1519         self.assertTrue('public.test_view' not in result)
     547        get_relations = self.db.get_relations
     548        result = get_relations()
     549        self.assertIn('public.test', result)
     550        self.assertIn('public.test_view', result)
     551        result = get_relations('rv')
     552        self.assertIn('public.test', result)
     553        self.assertIn('public.test_view', result)
     554        result = get_relations('r')
     555        self.assertIn('public.test', result)
     556        self.assertNotIn('public.test_view', result)
     557        result = get_relations('v')
     558        self.assertNotIn('public.test', result)
     559        self.assertIn('public.test_view', result)
     560        result = get_relations('cisSt')
     561        self.assertNotIn('public.test', result)
     562        self.assertNotIn('public.test_view', result)
    1520563
    1521564    def testAttnames(self):
     
    1525568            self.db.get_attnames, 'has.too.many.dots')
    1526569        for table in ('attnames_test_table', 'test table for attnames'):
    1527             smart_ddl(self.db, 'drop table "%s"' % table)
    1528             smart_ddl(self.db, 'create table "%s" ('
     570            self.db.query('drop table if exists "%s"' % table)
     571            self.db.query('create table "%s" ('
    1529572                'a smallint, b integer, c bigint, '
    1530573                'e numeric, f float, f2 double precision, m money, '
     
    1532575                'Normal_NaMe smallint, "Special Name" smallint, '
    1533576                't text, u char(2), v varchar(2), '
    1534                 'primary key (y, u))' % table)
     577                'primary key (y, u)) with oids' % table)
    1535578            attributes = self.db.get_attnames(table)
    1536579            result = {'a': 'int', 'c': 'int', 'b': 'int',
     
    1540583                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
    1541584            self.assertEqual(attributes, result)
     585            self.db.query('drop table "%s"' % table)
    1542586
    1543587    def testHasTablePrivilege(self):
     
    1554598
    1555599    def testGet(self):
     600        get = self.db.get
     601        query = self.db.query
    1556602        for table in ('get_test_table', 'test table for get'):
    1557             smart_ddl(self.db, 'drop table "%s"' % table)
    1558             smart_ddl(self.db, 'create table "%s" ('
    1559                 "n integer, t text)" % table)
     603            query('drop table if exists "%s"' % table)
     604            query('create table "%s" ('
     605                "n integer, t text) with oids" % table)
    1560606            for n, t in enumerate('xyz'):
    1561                 self.db.query('insert into "%s" values('
    1562                     "%d, '%s')" % (table, n + 1, t))
    1563             self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
    1564             r = self.db.get(table, 2, 'n')
     607                query('insert into "%s" values('"%d, '%s')"
     608                    % (table, n + 1, t))
     609            self.assertRaises(pg.ProgrammingError, get, table, 2)
     610            r = get(table, 2, 'n')
    1565611            oid_table = table
    1566612            if ' ' in table:
    1567613                oid_table = '"%s"' % oid_table
    1568614            oid_table = 'oid(public.%s)' % oid_table
    1569             self.assertTrue(oid_table in r)
     615            self.assertIn(oid_table, r)
    1570616            oid = r[oid_table]
    1571             self.assertTrue(isinstance(oid, int))
     617            self.assertIsInstance(oid, int)
    1572618            result = {'t': 'y', 'n': 2, oid_table: oid}
    1573619            self.assertEqual(r, result)
    1574             self.assertEqual(self.db.get(table + ' *', 2, 'n'), r)
    1575             self.assertEqual(self.db.get(table, oid, 'oid')['t'], 'y')
    1576             self.assertEqual(self.db.get(table, 1, 'n')['t'], 'x')
    1577             self.assertEqual(self.db.get(table, 3, 'n')['t'], 'z')
    1578             self.assertEqual(self.db.get(table, 2, 'n')['t'], 'y')
    1579             self.assertRaises(pg.DatabaseError, self.db.get, table, 4, 'n')
     620            self.assertEqual(get(table + ' *', 2, 'n'), r)
     621            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
     622            self.assertEqual(get(table, 1, 'n')['t'], 'x')
     623            self.assertEqual(get(table, 3, 'n')['t'], 'z')
     624            self.assertEqual(get(table, 2, 'n')['t'], 'y')
     625            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
    1580626            r['n'] = 3
    1581             self.assertEqual(self.db.get(table, r, 'n')['t'], 'z')
    1582             self.assertEqual(self.db.get(table, 1, 'n')['t'], 'x')
    1583             self.db.query('alter table "%s" alter n set not null' % table)
    1584             self.db.query('alter table "%s" add primary key (n)' % table)
    1585             self.assertEqual(self.db.get(table, 3)['t'], 'z')
    1586             self.assertEqual(self.db.get(table, 1)['t'], 'x')
    1587             self.assertEqual(self.db.get(table, 2)['t'], 'y')
     627            self.assertEqual(get(table, r, 'n')['t'], 'z')
     628            self.assertEqual(get(table, 1, 'n')['t'], 'x')
     629            query('alter table "%s" alter n set not null' % table)
     630            query('alter table "%s" add primary key (n)' % table)
     631            self.assertEqual(get(table, 3)['t'], 'z')
     632            self.assertEqual(get(table, 1)['t'], 'x')
     633            self.assertEqual(get(table, 2)['t'], 'y')
    1588634            r['n'] = 1
    1589             self.assertEqual(self.db.get(table, r)['t'], 'x')
     635            self.assertEqual(get(table, r)['t'], 'x')
    1590636            r['n'] = 3
    1591             self.assertEqual(self.db.get(table, r)['t'], 'z')
     637            self.assertEqual(get(table, r)['t'], 'z')
    1592638            r['n'] = 2
    1593             self.assertEqual(self.db.get(table, r)['t'], 'y')
     639            self.assertEqual(get(table, r)['t'], 'y')
     640            query('drop table "%s"' % table)
    1594641
    1595642    def testGetWithCompositeKey(self):
     643        get = self.db.get
     644        query = self.db.query
    1596645        table = 'get_test_table_1'
    1597         smart_ddl(self.db, "drop table %s" % table)
    1598         smart_ddl(self.db, "create table %s ("
     646        query("drop table if exists %s" % table)
     647        query("create table %s ("
    1599648            "n integer, t text, primary key (n))" % table)
    1600649        for n, t in enumerate('abc'):
    1601             self.db.query("insert into %s values("
     650            query("insert into %s values("
    1602651                "%d, '%s')" % (table, n + 1, t))
    1603         self.assertEqual(self.db.get(table, 2)['t'], 'b')
     652        self.assertEqual(get(table, 2)['t'], 'b')
     653        query("drop table %s" % table)
    1604654        table = 'get_test_table_2'
    1605         smart_ddl(self.db, "drop table %s" % table)
    1606         smart_ddl(self.db, "create table %s ("
     655        query("drop table if exists %s" % table)
     656        query("create table %s ("
    1607657            "n integer, m integer, t text, primary key (n, m))" % table)
    1608658        for n in range(3):
    1609659            for m in range(2):
    1610660                t = chr(ord('a') + 2 * n + m)
    1611                 self.db.query("insert into %s values("
     661                query("insert into %s values("
    1612662                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
    1613         self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
    1614         self.assertEqual(self.db.get(table, dict(n=2, m=2))['t'], 'd')
    1615         self.assertEqual(self.db.get(table, dict(n=1, m=2),
    1616             ('n', 'm'))['t'], 'b')
    1617         self.assertEqual(self.db.get(table, dict(n=3, m=2),
    1618             frozenset(['n', 'm']))['t'], 'f')
     663        self.assertRaises(pg.ProgrammingError, get, table, 2)
     664        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
     665        self.assertEqual(get(table, dict(n=1, m=2),
     666                             ('n', 'm'))['t'], 'b')
     667        self.assertEqual(get(table, dict(n=3, m=2),
     668                             frozenset(['n', 'm']))['t'], 'f')
     669        query("drop table %s" % table)
    1619670
    1620671    def testGetFromView(self):
     
    1623674            "14, 'abc4')")
    1624675        r = self.db.get('test_view', 14, 'i4')
    1625         self.assertTrue('v4' in r)
     676        self.assertIn('v4', r)
    1626677        self.assertEqual(r['v4'], 'abc4')
    1627678
    1628679    def testInsert(self):
     680        insert = self.db.insert
     681        query = self.db.query
    1629682        for table in ('insert_test_table', 'test table for insert'):
    1630             smart_ddl(self.db, 'drop table "%s"' % table)
    1631             smart_ddl(self.db, 'create table "%s" ('
     683            query('drop table if exists "%s"' % table)
     684            query('create table "%s" ('
    1632685                "i2 smallint, i4 integer, i8 bigint,"
    1633686                " d numeric, f4 real, f8 double precision, m money,"
    1634687                " v4 varchar(4), c4 char(4), t text,"
    1635                 " b boolean, ts timestamp)" % table)
     688                " b boolean, ts timestamp) with oids" % table)
    1636689            oid_table = table
    1637690            if ' ' in table:
     
    1688741                expect = data.copy()
    1689742                expect.update(change)
    1690                 self.assertEqual(self.db.insert(table, data), data)
    1691                 self.assertTrue(oid_table in data)
     743                self.assertEqual(insert(table, data), data)
     744                self.assertIn(oid_table, data)
    1692745                oid = data[oid_table]
    1693                 self.assertTrue(isinstance(oid, int))
     746                self.assertIsInstance(oid, int)
    1694747                data = dict(item for item in data.iteritems()
    1695748                    if item[0] in expect)
     
    1708761                    self.assertEqual(ts[13], ':')
    1709762                self.assertEqual(data, expect)
    1710                 data = self.db.query(
     763                data = query(
    1711764                    'select oid,* from "%s"' % table).dictresult()[0]
    1712765                self.assertEqual(data['oid'], oid)
     
    1714767                    if item[0] in expect)
    1715768                self.assertEqual(data, expect)
    1716                 self.db.query('delete from "%s"' % table)
     769                query('delete from "%s"' % table)
     770            query('drop table "%s"' % table)
    1717771
    1718772    def testUpdate(self):
     773        update = self.db.update
     774        query = self.db.query
    1719775        for table in ('update_test_table', 'test table for update'):
    1720             smart_ddl(self.db, 'drop table "%s"' % table)
    1721             smart_ddl(self.db, 'create table "%s" ('
    1722                 "n integer, t text)" % table)
     776            query('drop table if exists "%s"' % table)
     777            query('create table "%s" ('
     778                "n integer, t text) with oids" % table)
    1723779            for n, t in enumerate('xyz'):
    1724                 self.db.query('insert into "%s" values('
     780                query('insert into "%s" values('
    1725781                    "%d, '%s')" % (table, n + 1, t))
    1726782            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
    1727783            r = self.db.get(table, 2, 'n')
    1728784            r['t'] = 'u'
    1729             s = self.db.update(table, r)
     785            s = update(table, r)
    1730786            self.assertEqual(s, r)
    1731             r = self.db.query('select t from "%s" where n=2' % table
    1732                 ).getresult()[0][0]
     787            r = query('select t from "%s" where n=2' % table
     788                      ).getresult()[0][0]
    1733789            self.assertEqual(r, 'u')
     790            query('drop table "%s"' % table)
    1734791
    1735792    def testUpdateWithCompositeKey(self):
     793        update = self.db.update
     794        query = self.db.query
    1736795        table = 'update_test_table_1'
    1737         smart_ddl(self.db, "drop table %s" % table)
    1738         smart_ddl(self.db, "create table %s ("
     796        query("drop table if exists %s" % table)
     797        query("create table %s ("
    1739798            "n integer, t text, primary key (n))" % table)
    1740799        for n, t in enumerate('abc'):
    1741             self.db.query("insert into %s values("
     800            query("insert into %s values("
    1742801                "%d, '%s')" % (table, n + 1, t))
    1743         self.assertRaises(pg.ProgrammingError, self.db.update,
    1744             table, dict(t='b'))
    1745         self.assertEqual(self.db.update(table, dict(n=2, t='d'))['t'], 'd')
    1746         r = self.db.query('select t from "%s" where n=2' % table
    1747             ).getresult()[0][0]
     802        self.assertRaises(pg.ProgrammingError, update,
     803                          table, dict(t='b'))
     804        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
     805        r = query('select t from "%s" where n=2' % table
     806                  ).getresult()[0][0]
    1748807        self.assertEqual(r, 'd')
     808        query("drop table %s" % table)
    1749809        table = 'update_test_table_2'
    1750         smart_ddl(self.db, "drop table %s" % table)
    1751         smart_ddl(self.db, "create table %s ("
     810        query("drop table if exists %s" % table)
     811        query("create table %s ("
    1752812            "n integer, m integer, t text, primary key (n, m))" % table)
    1753813        for n in range(3):
    1754814            for m in range(2):
    1755815                t = chr(ord('a') + 2 * n + m)
    1756                 self.db.query("insert into %s values("
     816                query("insert into %s values("
    1757817                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
    1758         self.assertRaises(pg.ProgrammingError, self.db.update,
    1759             table, dict(n=2, t='b'))
    1760         self.assertEqual(self.db.update(table,
    1761             dict(n=2, m=2, t='x'))['t'], 'x')
    1762         r = [r[0] for r in self.db.query('select t from "%s" where n=2'
     818        self.assertRaises(pg.ProgrammingError, update,
     819                          table, dict(n=2, t='b'))
     820        self.assertEqual(update(table,
     821                                dict(n=2, m=2, t='x'))['t'], 'x')
     822        r = [r[0] for r in query('select t from "%s" where n=2'
    1763823            ' order by m' % table).getresult()]
    1764824        self.assertEqual(r, ['c', 'x'])
     825        query("drop table %s" % table)
    1765826
    1766827    def testClear(self):
     828        clear = self.db.clear
     829        query = self.db.query
    1767830        for table in ('clear_test_table', 'test table for clear'):
    1768             smart_ddl(self.db, 'drop table "%s"' % table)
    1769             smart_ddl(self.db, 'create table "%s" ('
     831            query('drop table if exists "%s"' % table)
     832            query('create table "%s" ('
    1770833                "n integer, b boolean, d date, t text)" % table)
    1771             r = self.db.clear(table)
     834            r = clear(table)
    1772835            result = {'n': 0, 'b': 'f', 'd': '', 't': ''}
    1773836            self.assertEqual(r, result)
     
    1776839            r['b'] = 't'
    1777840            r['oid'] = 1L
    1778             r = self.db.clear(table, r)
     841            r = clear(table, r)
    1779842            result = {'a': 1, 'n': 0, 'b': 'f', 'd': '', 't': '', 'oid': 1L}
    1780843            self.assertEqual(r, result)
     844            query('drop table "%s"' % table)
    1781845
    1782846    def testDelete(self):
     847        delete = self.db.delete
     848        query = self.db.query
    1783849        for table in ('delete_test_table', 'test table for delete'):
    1784             smart_ddl(self.db, 'drop table "%s"' % table)
    1785             smart_ddl(self.db, 'create table "%s" ('
    1786                 "n integer, t text)" % table)
     850            query('drop table if exists "%s"' % table)
     851            query('create table "%s" ('
     852                "n integer, t text) with oids" % table)
    1787853            for n, t in enumerate('xyz'):
    1788                 self.db.query('insert into "%s" values('
     854                query('insert into "%s" values('
    1789855                    "%d, '%s')" % (table, n + 1, t))
    1790856            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
    1791857            r = self.db.get(table, 1, 'n')
    1792             s = self.db.delete(table, r)
     858            s = delete(table, r)
    1793859            self.assertEqual(s, 1)
    1794860            r = self.db.get(table, 3, 'n')
    1795             s = self.db.delete(table, r)
     861            s = delete(table, r)
    1796862            self.assertEqual(s, 1)
    1797             s = self.db.delete(table, r)
     863            s = delete(table, r)
    1798864            self.assertEqual(s, 0)
    1799             r = self.db.query('select * from "%s"' % table).dictresult()
     865            r = query('select * from "%s"' % table).dictresult()
    1800866            self.assertEqual(len(r), 1)
    1801867            r = r[0]
     
    1803869            self.assertEqual(r, result)
    1804870            r = self.db.get(table, 2, 'n')
    1805             s = self.db.delete(table, r)
     871            s = delete(table, r)
    1806872            self.assertEqual(s, 1)
    1807             s = self.db.delete(table, r)
     873            s = delete(table, r)
    1808874            self.assertEqual(s, 0)
    1809875            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
     876            query('drop table "%s"' % table)
    1810877
    1811878    def testDeleteWithCompositeKey(self):
     879        query = self.db.query
    1812880        table = 'delete_test_table_1'
    1813         smart_ddl(self.db, "drop table %s" % table)
    1814         smart_ddl(self.db, "create table %s ("
     881        query("drop table if exists %s" % table)
     882        query("create table %s ("
    1815883            "n integer, t text, primary key (n))" % table)
    1816884        for n, t in enumerate('abc'):
    1817             self.db.query("insert into %s values("
     885            query("insert into %s values("
    1818886                "%d, '%s')" % (table, n + 1, t))
    1819887        self.assertRaises(pg.ProgrammingError, self.db.delete,
    1820888            table, dict(t='b'))
    1821889        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
    1822         r = self.db.query('select t from "%s" where n=2' % table
    1823             ).getresult()
     890        r = query('select t from "%s" where n=2' % table
     891                  ).getresult()
    1824892        self.assertEqual(r, [])
    1825893        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
    1826         r = self.db.query('select t from "%s" where n=3' % table
    1827             ).getresult()[0][0]
     894        r = query('select t from "%s" where n=3' % table
     895                  ).getresult()[0][0]
    1828896        self.assertEqual(r, 'c')
     897        query("drop table %s" % table)
    1829898        table = 'delete_test_table_2'
    1830         smart_ddl(self.db, "drop table %s" % table)
    1831         smart_ddl(self.db, "create table %s ("
     899        query("drop table if exists %s" % table)
     900        query("create table %s ("
    1832901            "n integer, m integer, t text, primary key (n, m))" % table)
    1833902        for n in range(3):
    1834903            for m in range(2):
    1835904                t = chr(ord('a') + 2 * n + m)
    1836                 self.db.query("insert into %s values("
     905                query("insert into %s values("
    1837906                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
    1838907        self.assertRaises(pg.ProgrammingError, self.db.delete,
    1839908            table, dict(n=2, t='b'))
    1840909        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
    1841         r = [r[0] for r in self.db.query('select t from "%s" where n=2'
     910        r = [r[0] for r in query('select t from "%s" where n=2'
    1842911            ' order by m' % table).getresult()]
    1843912        self.assertEqual(r, ['c'])
    1844913        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
    1845         r = [r[0] for r in self.db.query('select t from "%s" where n=3'
     914        r = [r[0] for r in query('select t from "%s" where n=3'
    1846915            ' order by m' % table).getresult()]
    1847916        self.assertEqual(r, ['e', 'f'])
    1848917        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
    1849         r = [r[0] for r in self.db.query('select t from "%s" where n=3'
     918        r = [r[0] for r in query('select t from "%s" where n=3'
    1850919            ' order by m' % table).getresult()]
    1851920        self.assertEqual(r, ['f'])
     921        query("drop table %s" % table)
    1852922
    1853923    def testTransaction(self):
    1854         smart_ddl(self.db, "drop table test_table")
    1855         self.db.query("create table test_table (n integer)")
     924        query = self.db.query
     925        query("drop table if exists test_table")
     926        query("create table test_table (n integer)")
    1856927        self.db.begin()
    1857         self.db.query("insert into test_table values (1)")
    1858         self.db.query("insert into test_table values (2)")
     928        query("insert into test_table values (1)")
     929        query("insert into test_table values (2)")
    1859930        self.db.commit()
    1860931        self.db.begin()
    1861         self.db.query("insert into test_table values (3)")
    1862         self.db.query("insert into test_table values (4)")
     932        query("insert into test_table values (3)")
     933        query("insert into test_table values (4)")
    1863934        self.db.rollback()
    1864935        self.db.begin()
    1865         self.db.query("insert into test_table values (5)")
     936        query("insert into test_table values (5)")
    1866937        self.db.savepoint('before6')
    1867         self.db.query("insert into test_table values (6)")
     938        query("insert into test_table values (6)")
    1868939        self.db.rollback('before6')
    1869         self.db.query("insert into test_table values (7)")
     940        query("insert into test_table values (7)")
    1870941        self.db.commit()
    1871942        self.db.begin()
    1872943        self.db.savepoint('before8')
    1873         self.db.query("insert into test_table values (8)")
     944        query("insert into test_table values (8)")
    1874945        self.db.release('before8')
    1875946        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
    1876947        self.db.commit()
    1877948        self.db.start()
    1878         self.db.query("insert into test_table values (9)")
     949        query("insert into test_table values (9)")
    1879950        self.db.end()
    1880         r = [r[0] for r in self.db.query(
     951        r = [r[0] for r in query(
    1881952            "select * from test_table order by 1").getresult()]
    1882953        self.assertEqual(r, [1, 2, 5, 7, 9])
     954        query("drop table test_table")
    1883955
    1884956    def testContextManager(self):
    1885         smart_ddl(self.db, "drop table test_table")
    1886         self.db.query("create table test_table (n integer check(n>0))")
     957        query = self.db.query
     958        query("drop table if exists test_table")
     959        query("create table test_table (n integer check(n>0))")
    1887960        with self.db:
    1888             self.db.query("insert into test_table values (1)")
    1889             self.db.query("insert into test_table values (2)")
     961            query("insert into test_table values (1)")
     962            query("insert into test_table values (2)")
    1890963        try:
    1891964            with self.db:
    1892                 self.db.query("insert into test_table values (3)")
    1893                 self.db.query("insert into test_table values (4)")
     965                query("insert into test_table values (3)")
     966                query("insert into test_table values (4)")
    1894967                raise ValueError('test transaction should rollback')
    1895968        except ValueError, error:
    1896969            self.assertEqual(str(error), 'test transaction should rollback')
    1897970        with self.db:
    1898             self.db.query("insert into test_table values (5)")
     971            query("insert into test_table values (5)")
    1899972        try:
    1900973            with self.db:
    1901                 self.db.query("insert into test_table values (6)")
    1902                 self.db.query("insert into test_table values (-1)")
     974                query("insert into test_table values (6)")
     975                query("insert into test_table values (-1)")
    1903976        except pg.ProgrammingError, error:
    1904977            self.assertTrue('check' in str(error))
    1905978        with self.db:
    1906             self.db.query("insert into test_table values (7)")
    1907         r = [r[0] for r in self.db.query(
     979            query("insert into test_table values (7)")
     980        r = [r[0] for r in query(
    1908981            "select * from test_table order by 1").getresult()]
    1909982        self.assertEqual(r, [1, 2, 5, 7])
     983        query("drop table test_table")
    1910984
    1911985    def testBytea(self):
    1912         smart_ddl(self.db, 'drop table bytea_test')
    1913         smart_ddl(self.db, 'create table bytea_test ('
     986        query = self.db.query
     987        query('drop table if exists bytea_test')
     988        query('create table bytea_test ('
    1914989            'data bytea)')
    1915990        s = "It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
    1916991        r = self.db.escape_bytea(s)
    1917         self.db.query('insert into bytea_test values('
     992        query('insert into bytea_test values('
    1918993            "'%s')" % r)
    1919         r = self.db.query('select * from bytea_test').getresult()
     994        r = query('select * from bytea_test').getresult()
    1920995        self.assertTrue(len(r) == 1)
    1921996        r = r[0]
     
    1924999        r = self.db.unescape_bytea(r)
    19251000        self.assertEqual(r, s)
     1001        query('drop table bytea_test')
    19261002
    19271003    def testDebugWithCallable(self):
    1928         self.assertTrue(self.db.debug is None)
     1004        if debug:
     1005            self.assertEqual(self.db.debug, debug)
     1006        else:
     1007            self.assertIsNone(self.db.debug)
    19291008        s = []
    19301009        self.db.debug = s.append
     
    19341013            self.assertEqual(s, ["select 1", "select 2"])
    19351014        finally:
    1936             self.db.debug = None
     1015            self.db.debug = debug
    19371016
    19381017
     
    19401019    """"Test correct handling of schemas (namespaces)."""
    19411020
    1942     # Test database needed: must be run as a DBTestSuite.
     1021    @classmethod
     1022    def setUpClass(cls):
     1023        db = DB()
     1024        query = db.query
     1025        query("set client_min_messages=warning")
     1026        for num_schema in range(5):
     1027            if num_schema:
     1028                schema = "s%d" % num_schema
     1029                query("drop schema if exists %s cascade" % (schema,))
     1030                try:
     1031                    query("create schema %s" % (schema,))
     1032                except pg.ProgrammingError:
     1033                    raise RuntimeError("The test user cannot create schemas.\n"
     1034                        "Grant create on database %s to the user"
     1035                        " for running these tests." % dbname)
     1036            else:
     1037                schema = "public"
     1038                query("drop table if exists %s.t" % (schema,))
     1039                query("drop table if exists %s.t%d" % (schema, num_schema))
     1040            query("create table %s.t with oids as select 1 as n, %d as d"
     1041                  % (schema, num_schema))
     1042            query("create table %s.t%d with oids as select 1 as n, %d as d"
     1043                  % (schema, num_schema, num_schema))
     1044        db.close()
     1045
     1046    @classmethod
     1047    def tearDownClass(cls):
     1048        db = DB()
     1049        query = db.query
     1050        query("set client_min_messages=warning")
     1051        for num_schema in range(5):
     1052            if num_schema:
     1053                schema = "s%d" % num_schema
     1054                query("drop schema %s cascade" % (schema,))
     1055            else:
     1056                schema = "public"
     1057                query("drop table %s.t" % (schema,))
     1058                query("drop table %s.t%d" % (schema, num_schema))
     1059        db.close()
    19431060
    19441061    def setUp(self):
    1945         dbname = DBTestSuite.dbname
    1946         self.dbname = dbname
    1947         self.db = pg.DB(dbname)
     1062        self.db = DB()
     1063        self.db.query("set client_min_messages=warning")
    19481064
    19491065    def tearDown(self):
     
    19581074                schema = "public"
    19591075            for t in (schema + ".t",
    1960                 schema + ".t" + str(num_schema)):
    1961                 self.assertTrue(t in tables, t + ' not in get_tables()')
     1076                    schema + ".t" + str(num_schema)):
     1077                self.assertIn(t, tables)
    19621078
    19631079    def testGetAttnames(self):
     1080        get_attnames = self.db.get_attnames
     1081        query = self.db.query
    19641082        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
    1965         r = self.db.get_attnames("t")
     1083        r = get_attnames("t")
    19661084        self.assertEqual(r, result)
    1967         r = self.db.get_attnames("s4.t4")
     1085        r = get_attnames("s4.t4")
    19681086        self.assertEqual(r, result)
    1969         smart_ddl(self.db, "create table s3.t3m"
    1970             " as select 1 as m")
     1087        query("drop table if exists s3.t3m")
     1088        query("create table s3.t3m with oids as select 1 as m")
    19711089        result_m = {'oid': 'int', 'm': 'int'}
    1972         r = self.db.get_attnames("s3.t3m")
     1090        r = get_attnames("s3.t3m")
    19731091        self.assertEqual(r, result_m)
    1974         self.db.query("set search_path to s1,s3")
    1975         r = self.db.get_attnames("t3")
     1092        query("set search_path to s1,s3")
     1093        r = get_attnames("t3")
    19761094        self.assertEqual(r, result)
    1977         r = self.db.get_attnames("t3m")
     1095        r = get_attnames("t3m")
    19781096        self.assertEqual(r, result_m)
     1097        query("drop table s3.t3m")
    19791098
    19801099    def testGet(self):
     1100        get = self.db.get
     1101        query = self.db.query
    19811102        PrgError = pg.ProgrammingError
    1982         self.assertEqual(self.db.get("t", 1, 'n')['d'], 0)
    1983         self.assertEqual(self.db.get("t0", 1, 'n')['d'], 0)
    1984         self.assertEqual(self.db.get("public.t", 1, 'n')['d'], 0)
    1985         self.assertEqual(self.db.get("public.t0", 1, 'n')['d'], 0)
    1986         self.assertRaises(PrgError, self.db.get, "public.t1", 1, 'n')
    1987         self.assertEqual(self.db.get("s1.t1", 1, 'n')['d'], 1)
    1988         self.assertEqual(self.db.get("s3.t", 1, 'n')['d'], 3)
    1989         self.db.query("set search_path to s2,s4")
    1990         self.assertRaises(PrgError, self.db.get, "t1", 1, 'n')
    1991         self.assertEqual(self.db.get("t4", 1, 'n')['d'], 4)
    1992         self.assertRaises(PrgError, self.db.get, "t3", 1, 'n')
    1993         self.assertEqual(self.db.get("t", 1, 'n')['d'], 2)
    1994         self.assertEqual(self.db.get("s3.t3", 1, 'n')['d'], 3)
    1995         self.db.query("set search_path to s1,s3")
    1996         self.assertRaises(PrgError, self.db.get, "t2", 1, 'n')
    1997         self.assertEqual(self.db.get("t3", 1, 'n')['d'], 3)
    1998         self.assertRaises(PrgError, self.db.get, "t4", 1, 'n')
    1999         self.assertEqual(self.db.get("t", 1, 'n')['d'], 1)
    2000         self.assertEqual(self.db.get("s4.t4", 1, 'n')['d'], 4)
     1103        self.assertEqual(get("t", 1, 'n')['d'], 0)
     1104        self.assertEqual(get("t0", 1, 'n')['d'], 0)
     1105        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
     1106        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
     1107        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
     1108        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
     1109        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
     1110        query("set search_path to s2,s4")
     1111        self.assertRaises(PrgError, get, "t1", 1, 'n')
     1112        self.assertEqual(get("t4", 1, 'n')['d'], 4)
     1113        self.assertRaises(PrgError, get, "t3", 1, 'n')
     1114        self.assertEqual(get("t", 1, 'n')['d'], 2)
     1115        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
     1116        query("set search_path to s1,s3")
     1117        self.assertRaises(PrgError, get, "t2", 1, 'n')
     1118        self.assertEqual(get("t3", 1, 'n')['d'], 3)
     1119        self.assertRaises(PrgError, get, "t4", 1, 'n')
     1120        self.assertEqual(get("t", 1, 'n')['d'], 1)
     1121        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)
    20011122
    20021123    def testMangling(self):
    2003         r = self.db.get("t", 1, 'n')
    2004         self.assertTrue('oid(public.t)' in r)
    2005         self.db.query("set search_path to s2")
    2006         r = self.db.get("t2", 1, 'n')
    2007         self.assertTrue('oid(s2.t2)' in r)
    2008         self.db.query("set search_path to s3")
    2009         r = self.db.get("t", 1, 'n')
    2010         self.assertTrue('oid(s3.t)' in r)
    2011 
    2012 
    2013 class DBTestSuite(unittest.TestSuite):
    2014     """Test suite that provides a test database."""
    2015 
    2016     dbname = "testpg_tempdb"
    2017 
    2018     # It would be too slow to create and drop the test database for
    2019     # every single test, so it is done once for the whole suite only.
    2020 
    2021     def setUp(self):
    2022         dbname = self.dbname
    2023         c = pg.connect("template1")
    2024         try:
    2025             c.query("drop database " + dbname)
    2026         except pg.Error:
    2027             pass
    2028         c.query("create database " + dbname
    2029             + " template=template0 encoding='UTF8'")
    2030         for s in ('client_min_messages = warning',
    2031             'client_encoding = UTF8',
    2032             'date_style = ISO, MDY',
    2033             'lc_monetary = C',
    2034             'default_with_oids = on',
    2035             'standard_conforming_strings = off',
    2036             'escape_string_warning = off'):
    2037             smart_ddl(c, 'alter database %s set %s' % (dbname, s))
    2038         c.close()
    2039         c = pg.connect(dbname)
    2040         smart_ddl(c, "create table test ("
    2041             "i2 smallint, i4 integer, i8 bigint,"
    2042             "d numeric, f4 real, f8 double precision, m money, "
    2043             "v4 varchar(4), c4 char(4), t text)")
    2044         c.query("create view test_view as"
    2045             " select i4, v4 from test")
    2046         for num_schema in range(5):
    2047             if num_schema:
    2048                 schema = "s%d" % num_schema
    2049                 c.query("create schema " + schema)
    2050             else:
    2051                 schema = "public"
    2052             smart_ddl(c, "create table %s.t"
    2053                 " as select 1 as n, %d as d"
    2054                 % (schema, num_schema))
    2055             smart_ddl(c, "create table %s.t%d"
    2056                 " as select 1 as n, %d as d"
    2057                 % (schema, num_schema, num_schema))
    2058         c.close()
    2059 
    2060     def tearDown(self):
    2061         dbname = self.dbname
    2062         c = pg.connect(dbname)
    2063         c.query("checkpoint")
    2064         c.close()
    2065         c = pg.connect("template1")
    2066         c.query("drop database " + dbname)
    2067         c.close()
    2068 
    2069     def __call__(self, result):
    2070         self.setUp()
    2071         unittest.TestSuite.__call__(self, result)
    2072         self.tearDown()
     1124        get = self.db.get
     1125        query = self.db.query
     1126        r = get("t", 1, 'n')
     1127        self.assertIn('oid(public.t)', r)
     1128        query("set search_path to s2")
     1129        r = get("t2", 1, 'n')
     1130        self.assertIn('oid(s2.t2)', r)
     1131        query("set search_path to s3")
     1132        r = get("t", 1, 'n')
     1133        self.assertIn('oid(s3.t)', r)
    20731134
    20741135
    20751136if __name__ == '__main__':
    2076 
    2077     # All tests that do not need a database:
    2078     TestSuite1 = unittest.TestSuite((
    2079         unittest.makeSuite(TestAuxiliaryFunctions),
    2080         unittest.makeSuite(TestHasConnect),
    2081         unittest.makeSuite(TestEscapeFunctions),
    2082         unittest.makeSuite(TestCanConnect),
    2083         unittest.makeSuite(TestConnectObject),
    2084         unittest.makeSuite(TestDBClassBasic),
    2085         ))
    2086 
    2087     # All tests that need a test database:
    2088     TestSuite2 = DBTestSuite((
    2089         unittest.makeSuite(TestSimpleQueries),
    2090         unittest.makeSuite(TestParamQueries),
    2091         unittest.makeSuite(TestInserttable),
    2092         unittest.makeSuite(TestNoticeReceiver),
    2093         unittest.makeSuite(TestDBClass),
    2094         unittest.makeSuite(TestSchemas),
    2095         ))
    2096 
    2097     # All tests together in one test suite:
    2098     TestSuite = unittest.TestSuite((
    2099         TestSuite1,
    2100         TestSuite2,
    2101     ))
    2102 
    2103     opts = dict(verbosity=1)
    2104     if '-v' in sys.argv:
    2105         opts.update(verbosity=2)
    2106     if '-f' in sys.argv:
    2107         opts.update(failfast=True)  # needs Python 2.7
    2108 
    2109     rc = unittest.TextTestRunner(**opts).run(TestSuite)
    2110 
    2111     sys.exit(1 if rc.errors or rc.failures else 0)
     1137    unittest.main()
Note: See TracChangeset for help on using the changeset viewer.