Changeset 745 for trunk


Ignore:
Timestamp:
Jan 14, 2016, 8:16:23 PM (4 years ago)
Author:
cito
Message:

Add methods get/set_parameter to DB wrapper class

These methods can be used to get/set/reset run-time parameters,
even several at once.

Since this is pretty useful and will not break anything, I have
also back ported these additions to the 4.x branch.

Everything is well documented and tested, of course.

Location:
trunk
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • trunk/docs/contents/changelog.rst

    r730 r745  
    4343Version 4.2
    4444-----------
    45 - Set a better default for the user option "escaping-funcs".
    4645- The supported Python versions are 2.4 to 2.7.
    4746- PostgreSQL is supported in all versions from 8.3 to 9.5.
     47- Set a better default for the user option "escaping-funcs".
    4848- Force build to compile with no errors.
     49- New methods get_parameters() and set_parameters() in the classic interface
     50  which can be used to get or set run-time parameters.
    4951- Fix decimal point handling.
    5052- Add option to return boolean values as bool objects.
  • trunk/docs/contents/pg/db_wrapper.rst

    r739 r745  
    121121
    122122    :param str table: name of table
    123     :returns: A dictionary -- the keys are the attribute names,
    124      the values are the type names of the attributes.
     123    :returns: a dictionary mapping attribute names to type names
    125124
    126125Given the name of a table, digs out the set of attribute names.
     
    135134You can get the regular types after enabling this by calling the
    136135:meth:`DB.use_regtypes` method.
    137 
    138136
    139137has_table_privilege -- check table privilege
     
    152150
    153151.. versionadded:: 4.0
     152
     153get/set_parameter -- get or set  run-time parameters
     154----------------------------------------------------
     155
     156.. method:: DB.get_parameter(parameter)
     157
     158    Get the value of run-time parameters
     159
     160    :param parameter: the run-time parameter(s) to get
     161    :type param: str, tuple, list or dict
     162    :returns: the current value(s) of the run-time parameter(s)
     163    :rtype: str, list or dict
     164    :raises TypeError: Invalid parameter type(s)
     165    :raises ProgrammingError: Invalid parameter name(s)
     166
     167If the parameter is a string, the return value will also be a string
     168that is the current setting of the run-time parameter with that name.
     169
     170You can get several parameters at once by passing a list or tuple of
     171parameter names.  The return value will then be a corresponding list
     172of parameter settings.  If you pass a dict as parameter instead, its
     173values will be set to the parameter settings corresponding to its keys.
     174
     175By passing the special name `'all'` as the parameter, you can get a dict
     176of all existing configuration parameters.
     177
     178.. versionadded:: 4.2
     179
     180.. method:: DB.set_parameter(self, parameter, [value], [local])
     181
     182    Set the value of run-time parameters
     183
     184    :param parameter: the run-time parameter(s) to set
     185    :type param: string, tuple, list or dict
     186    :param value: the value to set
     187    :type param: str or None
     188    :raises TypeError: Invalid parameter type(s)
     189    :raises ValueError: Invalid value argument(s)
     190    :raises ProgrammingError: Invalid parameter name(s) or values
     191
     192If the parameter and the value are strings, the run-time parameter
     193will be set to that value.  If no value or *None* is passed as a value,
     194then the run-time parameter will be restored to its default value.
     195
     196You can set several parameters at once by passing a list or tuple
     197of parameter names, with a single value that all parameters should
     198be set to or with a corresponding list or tuple of values.
     199
     200You can also pass a dict as parameters.  In this case, you should
     201not pass a value, since the values will be taken from the dict.
     202
     203By passing the special name `'all'` as the parameter, you can reset
     204all existing settable run-time parameters to their default values.
     205
     206If you set *local* to `True`, then the command takes effect for only the
     207current transaction.  After :meth:`DB.commit` or :meth:`DB.rollback`,
     208the session-level setting takes effect again.  Setting *local* to `True`
     209will appear to have no effect if it is executed outside a transaction,
     210since the transaction will end immediately.
     211
     212.. versionadded:: 4.2
    154213
    155214begin/commit/rollback/savepoint/release -- transaction handling
  • trunk/pg.py

    r743 r745  
    462462        """Destroy a previously defined savepoint."""
    463463        return self.query('RELEASE ' + name)
     464
     465    def get_parameter(self, parameter):
     466        """Get the value of a run-time parameter.
     467
     468        If the parameter is a string, the return value will also be a string
     469        that is the current setting of the run-time parameter with that name.
     470
     471        You can get several parameters at once by passing a list or tuple of
     472        parameter names.  The return value will then be a corresponding list
     473        of parameter settings.  If you pass a dict as parameter instead, its
     474        values will be set to the parameter settings corresponding to its keys.
     475
     476        By passing the special name 'all' as the parameter, you can get a dict
     477        of all existing configuration parameters.
     478        """
     479        if isinstance(parameter, basestring):
     480            parameter = [parameter]
     481            values = None
     482        elif isinstance(parameter, (list, tuple)):
     483            values = []
     484        elif isinstance(parameter, dict):
     485            values = parameter
     486        else:
     487            raise TypeError('The parameter must be a dict, list or string')
     488        if not parameter:
     489            raise TypeError('No parameter has been specified')
     490        params = {} if isinstance(values, dict) else []
     491        for key in parameter:
     492            param = key.strip().lower() if isinstance(
     493                key, basestring) else None
     494            if not param:
     495                raise TypeError('Invalid parameter')
     496            if param == 'all':
     497                q = 'SHOW ALL'
     498                values = self.db.query(q).getresult()
     499                values = dict(value[:2] for value in values)
     500                break
     501            if isinstance(values, dict):
     502                params[param] = key
     503            else:
     504                params.append(param)
     505        else:
     506            for param in params:
     507                q = 'SHOW %s' % (param,)
     508                value = self.db.query(q).getresult()[0][0]
     509                if values is None:
     510                    values = value
     511                elif isinstance(values, list):
     512                    values.append(value)
     513                else:
     514                    values[params[param]] = value
     515        return values
     516
     517    def set_parameter(self, parameter, value=None, local=False):
     518        """Set the value of a run-time parameter.
     519
     520        If the parameter and the value are strings, the run-time parameter
     521        will be set to that value.  If no value or None is passed as a value,
     522        then the run-time parameter will be restored to its default value.
     523
     524        You can set several parameters at once by passing a list or tuple
     525        of parameter names, with a single value that all parameters should
     526        be set to or with a corresponding list or tuple of values.
     527
     528        You can also pass a dict as parameters.  In this case, you should
     529        not pass a value, since the values will be taken from the dict.
     530
     531        By passing the special name 'all' as the parameter, you can reset
     532        all existing settable run-time parameters to their default values.
     533
     534        If you set local to True, then the command takes effect for only the
     535        current transaction.  After commit() or rollback(), the session-level
     536        setting takes effect again.  Setting local to True will appear to
     537        have no effect if it is executed outside a transaction, since the
     538        transaction will end immediately.
     539        """
     540        if isinstance(parameter, basestring):
     541            parameter = {parameter: value}
     542        elif isinstance(parameter, (list, tuple)):
     543            if isinstance(value, (list, tuple)):
     544                parameter = dict(zip(parameter, value))
     545            else:
     546                parameter = dict.fromkeys(parameter, value)
     547        elif isinstance(parameter, dict):
     548            if value is not None:
     549                raise ValueError(
     550                    'A value must not be set when parameter is a dictionary')
     551        else:
     552            raise TypeError('The parameter must be a dict, list or string')
     553        if not parameter:
     554            raise TypeError('No parameter has been specified')
     555        params = {}
     556        for key, value in parameter.items():
     557            param = key.strip().lower() if isinstance(
     558                key, basestring) else None
     559            if not param:
     560                raise TypeError('Invalid parameter')
     561            if param == 'all':
     562                if value is not None:
     563                    raise ValueError(
     564                        "A value must ot be set when parameter is 'all'")
     565                params = {'all': None}
     566                break
     567            params[param] = value
     568        local = ' LOCAL' if local else ''
     569        for param, value in params.items():
     570            if value is None:
     571                q = 'RESET%s %s' % (local, param)
     572            else:
     573                q = 'SET%s %s TO %s' % (local, param, value)
     574            self._do_debug(q)
     575            self.db.query(q)
    464576
    465577    def query(self, qstr, *args):
  • trunk/tests/test_classic_dbwrapper.py

    r744 r745  
    8686        attributes = [
    8787            'begin',
    88             'cancel',
    89             'clear',
    90             'close',
    91             'commit',
    92             'db',
    93             'dbname',
    94             'debug',
    95             'delete',
    96             'end',
    97             'endcopy',
    98             'error',
    99             'escape_bytea',
    100             'escape_identifier',
    101             'escape_literal',
    102             'escape_string',
     88            'cancel', 'clear', 'close', 'commit',
     89            'db', 'dbname', 'debug', 'delete',
     90            'end', 'endcopy', 'error',
     91            'escape_bytea', 'escape_identifier',
     92            'escape_literal', 'escape_string',
    10393            'fileno',
    104             'get',
    105             'get_attnames',
    106             'get_databases',
    107             'get_notice_receiver',
    108             'get_relations',
    109             'get_tables',
    110             'getline',
    111             'getlo',
    112             'getnotify',
    113             'has_table_privilege',
    114             'host',
    115             'insert',
    116             'inserttable',
    117             'locreate',
    118             'loimport',
     94            'get', 'get_attnames', 'get_databases',
     95            'get_notice_receiver', 'get_parameter',
     96            'get_relations', 'get_tables',
     97            'getline', 'getlo', 'getnotify',
     98            'has_table_privilege', 'host',
     99            'insert', 'inserttable',
     100            'locreate', 'loimport',
    119101            'notification_handler',
    120102            'options',
    121             'parameter',
    122             'pkey',
    123             'port',
    124             'protocol_version',
    125             'putline',
     103            'parameter', 'pkey', 'port',
     104            'protocol_version', 'putline',
    126105            'query',
    127             'release',
    128             'reopen',
    129             'reset',
    130             'rollback',
    131             'savepoint',
    132             'server_version',
    133             'set_notice_receiver',
    134             'source',
    135             'start',
    136             'status',
     106            'release', 'reopen', 'reset', 'rollback',
     107            'savepoint', 'server_version',
     108            'set_notice_receiver', 'set_parameter',
     109            'source', 'start', 'status',
    137110            'transaction',
    138             'unescape_bytea',
    139             'update',
    140             'upsert',
    141             'use_regtypes',
    142             'user',
     111            'unescape_bytea', 'update', 'upsert',
     112            'use_regtypes', 'user',
    143113        ]
    144114        db_attributes = [a for a in dir(self.db)
     
    288258        db.query("create table test ("
    289259            "i2 smallint, i4 integer, i8 bigint,"
    290             "d numeric, f4 real, f8 double precision, m money, "
    291             "v4 varchar(4), c4 char(4), t text)")
     260            " d numeric, f4 real, f8 double precision, m money,"
     261            " v4 varchar(4), c4 char(4), t text)")
    292262        db.query("create or replace view test_view as"
    293263            " select i4, v4 from test")
     
    413383        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
    414384
    415     def testGetAttnames(self):
    416         get_attnames = self.db.get_attnames
    417         query = self.db.query
    418         query("drop table if exists test_table")
    419         self.addCleanup(query, "drop table test_table")
    420         query("create table test_table("
    421             " n int, alpha smallint, beta bool,"
    422             " gamma char(5), tau text, v varchar(3))")
    423         r = get_attnames("test_table")
     385    def testGetParameter(self):
     386        f = self.db.get_parameter
     387        self.assertRaises(TypeError, f)
     388        self.assertRaises(TypeError, f, None)
     389        self.assertRaises(TypeError, f, 42)
     390        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
     391        r = f('standard_conforming_strings')
     392        self.assertEqual(r, 'on')
     393        r = f('lc_monetary')
     394        self.assertEqual(r, 'C')
     395        r = f('datestyle')
     396        self.assertEqual(r, 'ISO, YMD')
     397        r = f('bytea_output')
     398        self.assertEqual(r, 'hex')
     399        r = f(('bytea_output', 'lc_monetary'))
     400        self.assertIsInstance(r, list)
     401        self.assertEqual(r, ['hex', 'C'])
     402        r = f(['standard_conforming_strings', 'datestyle', 'bytea_output'])
     403        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
     404        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
     405        r = f(s)
     406        self.assertIs(r, s)
     407        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
     408        s = dict.fromkeys(('Bytea_Output', 'LC_Monetary'))
     409        r = f(s)
     410        self.assertIs(r, s)
     411        self.assertEqual(r, {'Bytea_Output': 'hex', 'LC_Monetary': 'C'})
     412
     413    def testGetParameterServerVersion(self):
     414        r = self.db.get_parameter('server_version_num')
     415        self.assertIsInstance(r, str)
     416        s = self.db.server_version
     417        self.assertIsInstance(s, int)
     418        self.assertEqual(r, str(s))
     419
     420    def testGetParameterAll(self):
     421        f = self.db.get_parameter
     422        r = f('all')
    424423        self.assertIsInstance(r, dict)
    425         self.assertEquals(r, dict(
    426             n='int', alpha='int', beta='bool',
    427             gamma='text', tau='text', v='text'))
    428 
    429     def testGetAttnamesWithQuotes(self):
    430         get_attnames = self.db.get_attnames
    431         query = self.db.query
    432         table = 'test table for get_attnames()'
    433         query('drop table if exists "%s"' % table)
    434         self.addCleanup(query, 'drop table "%s"' % table)
    435         query('create table "%s"('
    436             '"Prime!" smallint,'
    437             '"much space" integer, "Questions?" text)' % table)
    438         r = get_attnames(table)
    439         self.assertIsInstance(r, dict)
    440         self.assertEquals(r, {
    441             'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
    442 
    443     def testGetAttnamesWithRegtypes(self):
    444         get_attnames = self.db.get_attnames
    445         query = self.db.query
    446         query("drop table if exists test_table")
    447         self.addCleanup(query, "drop table test_table")
    448         query("create table test_table("
    449             " n int, alpha smallint, beta bool,"
    450             " gamma char(5), tau text, v varchar(3))")
    451         self.db.use_regtypes(True)
    452         try:
    453             r = get_attnames("test_table")
    454             self.assertIsInstance(r, dict)
    455         finally:
    456             self.db.use_regtypes(False)
    457         self.assertEquals(r, dict(
    458             n='integer', alpha='smallint', beta='boolean',
    459             gamma='character', tau='text', v='character varying'))
    460 
    461     def testGetAttnamesIsCached(self):
    462         get_attnames = self.db.get_attnames
    463         query = self.db.query
    464         query("drop table if exists test_table")
    465         self.addCleanup(query, "drop table if exists test_table")
    466         query("create table test_table(col int)")
    467         r = get_attnames("test_table")
    468         self.assertIsInstance(r, dict)
    469         self.assertEquals(r, dict(col='int'))
    470         query("drop table test_table")
    471         query("create table test_table(col text)")
    472         r = get_attnames("test_table")
    473         self.assertEquals(r, dict(col='int'))
    474         r = get_attnames("test_table", flush=True)
    475         self.assertEquals(r, dict(col='text'))
    476         query("drop table test_table")
    477         r = get_attnames("test_table")
    478         self.assertEquals(r, dict(col='text'))
    479         self.assertRaises(pg.ProgrammingError,
    480             get_attnames, "test_table", flush=True)
    481 
    482     def testGetAttnamesIsOrdered(self):
    483         get_attnames = self.db.get_attnames
    484         query = self.db.query
    485         query("drop table if exists test_table")
    486         self.addCleanup(query, "drop table test_table")
    487         query("create table test_table("
    488             " n int, alpha smallint, v varchar(3),"
    489             " gamma char(5), tau text, beta bool)")
    490         r = get_attnames("test_table")
    491         self.assertIsInstance(r, OrderedDict)
    492         self.assertEquals(r, OrderedDict([
    493             ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
    494             ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
    495         if OrderedDict is dict:
    496             self.skipTest('OrderedDict is not supported')
    497         r = ' '.join(list(r.keys()))
    498         self.assertEquals(r, 'n alpha v gamma tau beta')
     424        self.assertEqual(r['standard_conforming_strings'], 'on')
     425        self.assertEqual(r['lc_monetary'], 'C')
     426        self.assertEqual(r['DateStyle'], 'ISO, YMD')
     427        self.assertEqual(r['bytea_output'], 'hex')
     428
     429    def testSetParameter(self):
     430        f = self.db.set_parameter
     431        g = self.db.get_parameter
     432        self.assertRaises(TypeError, f)
     433        self.assertRaises(TypeError, f, None)
     434        self.assertRaises(TypeError, f, 42)
     435        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
     436        f('standard_conforming_strings', 'off')
     437        self.assertEqual(g('standard_conforming_strings'), 'off')
     438        f('datestyle', 'ISO, DMY')
     439        self.assertEqual(g('datestyle'), 'ISO, DMY')
     440        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
     441        self.assertEqual(g('standard_conforming_strings'), 'on')
     442        self.assertEqual(g('datestyle'), 'ISO, YMD')
     443        f(['standard_conforming_strings', 'datestyle'], ['off', 'ISO, DMY'])
     444        self.assertEqual(g('standard_conforming_strings'), 'off')
     445        self.assertEqual(g('datestyle'), 'ISO, DMY')
     446        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
     447        self.assertEqual(g('standard_conforming_strings'), 'on')
     448        self.assertEqual(g('datestyle'), 'ISO, YMD')
     449        f(('default_with_oids', 'standard_conforming_strings'), 'off')
     450        self.assertEqual(g('default_with_oids'), 'off')
     451        self.assertEqual(g('standard_conforming_strings'), 'off')
     452        f(['default_with_oids', 'standard_conforming_strings'], 'on')
     453        self.assertEqual(g('default_with_oids'), 'on')
     454        self.assertEqual(g('standard_conforming_strings'), 'on')
     455
     456    def testResetParameter(self):
     457        db = DB()
     458        f = db.set_parameter
     459        g = db.get_parameter
     460        r = g('default_with_oids')
     461        self.assertIn(r, ('on', 'off'))
     462        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
     463        r = g('standard_conforming_strings')
     464        self.assertIn(r, ('on', 'off'))
     465        scs, not_scs = r, 'off' if r == 'on' else 'on'
     466        f('default_with_oids', not_dwi)
     467        f('standard_conforming_strings', not_scs)
     468        self.assertEqual(g('default_with_oids'), not_dwi)
     469        self.assertEqual(g('standard_conforming_strings'), not_scs)
     470        f('default_with_oids')
     471        f('standard_conforming_strings', None)
     472        self.assertEqual(g('default_with_oids'), dwi)
     473        self.assertEqual(g('standard_conforming_strings'), scs)
     474        f('default_with_oids', not_dwi)
     475        f('standard_conforming_strings', not_scs)
     476        self.assertEqual(g('default_with_oids'), not_dwi)
     477        self.assertEqual(g('standard_conforming_strings'), not_scs)
     478        f(('default_with_oids', 'standard_conforming_strings'))
     479        self.assertEqual(g('default_with_oids'), dwi)
     480        self.assertEqual(g('standard_conforming_strings'), scs)
     481        f('default_with_oids', not_dwi)
     482        f('standard_conforming_strings', not_scs)
     483        self.assertEqual(g('default_with_oids'), not_dwi)
     484        self.assertEqual(g('standard_conforming_strings'), not_scs)
     485        f(['default_with_oids', 'standard_conforming_strings'], None)
     486        self.assertEqual(g('default_with_oids'), dwi)
     487        self.assertEqual(g('standard_conforming_strings'), scs)
     488
     489    def testResetParameterAll(self):
     490        db = DB()
     491        f = db.set_parameter
     492        self.assertRaises(ValueError, f, 'all', 0)
     493        self.assertRaises(ValueError, f, 'all', 'off')
     494        g = db.get_parameter
     495        r = g('default_with_oids')
     496        self.assertIn(r, ('on', 'off'))
     497        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
     498        r = g('standard_conforming_strings')
     499        self.assertIn(r, ('on', 'off'))
     500        scs, not_scs = r, 'off' if r == 'on' else 'on'
     501        f('default_with_oids', not_dwi)
     502        f('standard_conforming_strings', not_scs)
     503        self.assertEqual(g('default_with_oids'), not_dwi)
     504        self.assertEqual(g('standard_conforming_strings'), not_scs)
     505        f('all')
     506        self.assertEqual(g('default_with_oids'), dwi)
     507        self.assertEqual(g('standard_conforming_strings'), scs)
     508
     509    def testSetParameterLocal(self):
     510        f = self.db.set_parameter
     511        g = self.db.get_parameter
     512        self.assertEqual(g('standard_conforming_strings'), 'on')
     513        self.db.begin()
     514        f('standard_conforming_strings', 'off', local=True)
     515        self.assertEqual(g('standard_conforming_strings'), 'off')
     516        self.db.end()
     517        self.assertEqual(g('standard_conforming_strings'), 'on')
     518
     519    def testSetParameterSession(self):
     520        f = self.db.set_parameter
     521        g = self.db.get_parameter
     522        self.assertEqual(g('standard_conforming_strings'), 'on')
     523        self.db.begin()
     524        f('standard_conforming_strings', 'off', local=False)
     525        self.assertEqual(g('standard_conforming_strings'), 'off')
     526        self.db.end()
     527        self.assertEqual(g('standard_conforming_strings'), 'off')
    499528
    500529    def testQuery(self):
     
    589618                "c smallint, d smallint primary key)" % t)
    590619            query('create table "%s3" ('
    591                 "e smallint, f smallint, g smallint, "
    592                 "h smallint, i smallint, "
    593                 "primary key (f, h))" % t)
     620                "e smallint, f smallint, g smallint,"
     621                " h smallint, i smallint,"
     622                " primary key (f, h))" % t)
    594623            query('create table "%s4" ('
    595624                "more_than_one_letter varchar primary key)" % t)
     
    597626                '"with space" date primary key)' % t)
    598627            query('create table "%s6" ('
    599                 'a_very_long_column_name varchar, '
    600                 '"with space" date, '
    601                 '"42" int, '
    602                 "primary key (a_very_long_column_name, "
    603                 '"with space", "42"))' % t)
     628                'a_very_long_column_name varchar,'
     629                ' "with space" date,'
     630                ' "42" int,'
     631                " primary key (a_very_long_column_name,"
     632                ' "with space", "42"))' % t)
    604633            self.assertRaises(KeyError, pkey, '%s0' % t)
    605634            self.assertEqual(pkey('%s1' % t), 'b')
     
    687716        self.assertNotIn('public.test_view', result)
    688717
    689     def testAttnames(self):
     718    def testGetAttnames(self):
     719        get_attnames = self.db.get_attnames
    690720        self.assertRaises(pg.ProgrammingError,
    691721            self.db.get_attnames, 'does_not_exist')
    692722        self.assertRaises(pg.ProgrammingError,
    693723            self.db.get_attnames, 'has.too.many.dots')
    694         for table in ('attnames_test_table', 'test table for attnames'):
    695             self.db.query('drop table if exists "%s"' % table)
    696             self.addCleanup(self.db.query, 'drop table "%s"' % table)
    697             self.db.query('create table "%s" ('
    698                 'a smallint, b integer, c bigint, '
    699                 'e numeric, f float, f2 double precision, m money, '
    700                 'x smallint, y smallint, z smallint, '
    701                 'Normal_NaMe smallint, "Special Name" smallint, '
    702                 't text, u char(2), v varchar(2), '
    703                 'primary key (y, u)) with oids' % table)
    704             attributes = self.db.get_attnames(table)
    705             result = {'a': 'int', 'c': 'int', 'b': 'int',
    706                 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
    707                 'normal_name': 'int', 'Special Name': 'int',
    708                 'u': 'text', 't': 'text', 'v': 'text',
    709                 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
    710             self.assertEqual(attributes, result)
     724        query = self.db.query
     725        query("drop table if exists test_table")
     726        self.addCleanup(query, "drop table test_table")
     727        query("create table test_table("
     728            " n int, alpha smallint, beta bool,"
     729            " gamma char(5), tau text, v varchar(3))")
     730        r = get_attnames('test_table')
     731        self.assertIsInstance(r, dict)
     732        self.assertEqual(r, dict(
     733            n='int', alpha='int', beta='bool',
     734            gamma='text', tau='text', v='text'))
     735
     736    def testGetAttnamesWithQuotes(self):
     737        get_attnames = self.db.get_attnames
     738        query = self.db.query
     739        table = 'test table for get_attnames()'
     740        query('drop table if exists "%s"' % table)
     741        self.addCleanup(query, 'drop table "%s"' % table)
     742        query('create table "%s"('
     743            '"Prime!" smallint,'
     744            ' "much space" integer, "Questions?" text)' % table)
     745        r = get_attnames(table)
     746        self.assertIsInstance(r, dict)
     747        self.assertEqual(r, {
     748            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
     749        table = 'yet another test table for get_attnames()'
     750        query('drop table if exists "%s"' % table)
     751        self.addCleanup(query, 'drop table "%s"' % table)
     752        self.db.query('create table "%s" ('
     753            'a smallint, b integer, c bigint,'
     754            ' e numeric, f float, f2 double precision, m money,'
     755            ' x smallint, y smallint, z smallint,'
     756            ' Normal_NaMe smallint, "Special Name" smallint,'
     757            ' t text, u char(2), v varchar(2),'
     758            ' primary key (y, u)) with oids' % table)
     759        r = get_attnames(table)
     760        self.assertIsInstance(r, dict)
     761        self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
     762            'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
     763            'normal_name': 'int', 'Special Name': 'int',
     764            'u': 'text', 't': 'text', 'v': 'text',
     765            'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
     766
     767    def testGetAttnamesWithRegtypes(self):
     768        get_attnames = self.db.get_attnames
     769        query = self.db.query
     770        query("drop table if exists test_table")
     771        self.addCleanup(query, "drop table test_table")
     772        query("create table test_table("
     773            " n int, alpha smallint, beta bool,"
     774            " gamma char(5), tau text, v varchar(3))")
     775        self.db.use_regtypes(True)
     776        try:
     777            r = get_attnames("test_table")
     778            self.assertIsInstance(r, dict)
     779        finally:
     780            self.db.use_regtypes(False)
     781        self.assertEqual(r, dict(
     782            n='integer', alpha='smallint', beta='boolean',
     783            gamma='character', tau='text', v='character varying'))
     784
     785    def testGetAttnamesIsCached(self):
     786        get_attnames = self.db.get_attnames
     787        query = self.db.query
     788        query("drop table if exists test_table")
     789        self.addCleanup(query, "drop table if exists test_table")
     790        query("create table test_table(col int)")
     791        r = get_attnames("test_table")
     792        self.assertIsInstance(r, dict)
     793        self.assertEqual(r, dict(col='int'))
     794        query("drop table test_table")
     795        query("create table test_table(col text)")
     796        r = get_attnames("test_table")
     797        self.assertEqual(r, dict(col='int'))
     798        r = get_attnames("test_table", flush=True)
     799        self.assertEqual(r, dict(col='text'))
     800        query("drop table test_table")
     801        r = get_attnames("test_table")
     802        self.assertEqual(r, dict(col='text'))
     803        self.assertRaises(pg.ProgrammingError,
     804            get_attnames, "test_table", flush=True)
     805
     806    def testGetAttnamesIsOrdered(self):
     807        get_attnames = self.db.get_attnames
     808        query = self.db.query
     809        query("drop table if exists test_table")
     810        self.addCleanup(query, "drop table test_table")
     811        query("create table test_table("
     812            " n int, alpha smallint, v varchar(3),"
     813            " gamma char(5), tau text, beta bool)")
     814        r = get_attnames("test_table")
     815        self.assertIsInstance(r, OrderedDict)
     816        self.assertEqual(r, OrderedDict([
     817            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
     818            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
     819        if OrderedDict is dict:
     820            self.skipTest('OrderedDict is not supported')
     821        r = ' '.join(list(r.keys()))
     822        self.assertEqual(r, 'n alpha v gamma tau beta')
    711823
    712824    def testHasTablePrivilege(self):
     
    801913        query('create table "%s" ('
    802914            '"Prime!" smallint primary key,'
    803             '"much space" integer, "Questions?" text)' % table)
     915            ' "much space" integer, "Questions?" text)' % table)
    804916        query('insert into "%s"'
    805917              " values(17, 1001, 'No!')" % table)
     
    9691081        query('create table "%s" ('
    9701082            '"Prime!" smallint primary key,'
    971             '"much space" integer, "Questions?" text)' % table)
     1083            ' "much space" integer, "Questions?" text)' % table)
    9721084        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
    9731085        r = insert(table, r)
     
    10601172        query('create table "%s" ('
    10611173            '"Prime!" smallint primary key,'
    1062             '"much space" integer, "Questions?" text)' % table)
     1174            ' "much space" integer, "Questions?" text)' % table)
    10631175        query('insert into "%s"'
    10641176              " values(13, 3003, 'Why!')" % table)
     
    12221334        query('create table "%s" ('
    12231335            '"Prime!" smallint primary key,'
    1224             '"much space" integer, "Questions?" text)' % table)
     1336            ' "much space" integer, "Questions?" text)' % table)
    12251337        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
    12261338        try:
     
    12751387        query('create table "%s" ('
    12761388            '"Prime!" smallint primary key,'
    1277             '"much space" integer, "Questions?" text)' % table)
     1389            ' "much space" integer, "Questions?" text)' % table)
    12781390        r = clear(table)
    12791391        self.assertIsInstance(r, dict)
     
    13671479        query('create table "%s" ('
    13681480            '"Prime!" smallint primary key,'
    1369             '"much space" integer, "Questions?" text)' % table)
     1481            ' "much space" integer, "Questions?" text)' % table)
    13701482        query('insert into "%s"'
    13711483              " values(19, 5005, 'Yes!')" % table)
Note: See TracChangeset for help on using the changeset viewer.