Changeset 797


Timestamp:
Jan 29, 2016, 5:43:22 PM
Author:
cito
Message:

Cache typecast functions and make them configurable

The typecast functions used by the pgdb module are now cached
using a local and a global Typecasts class. The local cache is
bound to the connection and knows how to cast composite types.

Also added functions that allow registering custom typecast
functions on the global and local level.

Also added a chapter on type adaptation and casting to the docs.

Location:
trunk
Files:
1 added
6 edited
1 copied

  • trunk/docs/contents/pgdb/connection.rst

    r796 r797  
    8484    A dictionary with the various type codes for the PostgreSQL types
    8585
    86 You can request the dictionary either via a PostgreSQL type name (which
    87 is equal to the DB-API 2 *type_code*) or via a PostgreSQL type OID.
    88 
    89 The values are *type_code* strings carrying additional attributes:
    90 
    91         - *oid* -- the OID of the type
    92         - *len*  -- the internal size
    93         - *type*  -- ``'b'`` = base, ``'c'`` = composite, ...
    94         - *category*  -- ``'A'`` = Array, ``'B'`` = Boolean, ...
    95         - *delim*  -- delimiter to be used when parsing arrays
    96         - *relid*  -- the table OID for composite types
    97 
    98 For details, see the PostgreSQL documentation on `pg_type
    99 <http://www.postgresql.org/docs/current/static/catalog-pg-type.html>`_.
    100 
    101 The :attr:`Connection.type_cache` also provides a method :meth:`columns`
    102 that returns the names and type OIDs of the columns of composite types.
     86This can be used for getting more information on the PostgreSQL database
     87types or changing the typecast functions used for the connection.  See the
     88description of the :class:`TypeCache` class for details.
    10389
    10490.. versionadded:: 5.0
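
As an illustration of the :attr:`Connection.type_cache` attribute documented above, here is a minimal sketch of looking up type information through the cache; the connection parameters are placeholders and the shown values follow the behaviour exercised in the test suite further below::

    import pgdb

    con = pgdb.connect(database='test')  # placeholder parameters

    info = con.type_cache['numeric']     # look up by type name ...
    # info is the type code 'numeric' carrying extra attributes:
    # info.oid == 1700, info.type == 'b' (base), info.category == 'N' (numeric)

    assert con.type_cache[1700] is info  # ... or by type OID

    con.close()
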
  • trunk/docs/contents/pgdb/index.rst

    r710 r797  
    1414    cursor
    1515    types
     16    typecache
     17    adaptation
  • trunk/docs/contents/pgdb/module.rst

    r751 r797  
    3636
    3737    con = connect(dsn='myhost:mydb', user='guido', password='234$')
     38
     39
     40get/set/reset_typecast -- Control the global typecast functions
     41---------------------------------------------------------------
     42
     43PyGreSQL uses typecast functions to cast the raw data coming from the
     44database to Python objects suitable for the particular database type.
     45These functions take a single string argument that represents the data
      46to be cast and must return the cast value.
     47
     48PyGreSQL provides built-in typecast functions for the common database types,
     49but if you want to change these or add more typecast functions, you can use
     50the following functions.
     51
     52.. note::
     53
     54    The following functions are not part of the DB-API 2 standard.
     55
      56.. function:: get_typecast(typ)
     57
     58    Get the global cast function for the given database type
     59
     60    :param str typ: PostgreSQL type name or type code
     61    :returns: the typecast function for the specified type
     62    :rtype: function or None
     63
     64.. versionadded:: 5.0
     65
      66.. function:: set_typecast(typ, cast)
     67
     68    Set a global typecast function for the given database type(s)
     69
     70    :param typ: PostgreSQL type name or type code, or list of such
     71    :type typ: str or list
     72    :param cast: the typecast function to be set for the specified type(s)
      73    :type cast: callable
     74
     75.. versionadded:: 5.0
     76
      77.. function:: reset_typecast([typ])
     78
     79    Reset the typecasts for the specified (or all) type(s) to their defaults
     80
      81    :param typ: PostgreSQL type name or type code, or list of such,
     82        or None to reset all typecast functions
     83    :type typ: str, list or None
     84
     85.. versionadded:: 5.0
     86
     87Note that database connections cache types and their cast functions using
      88connection-specific :class:`TypeCache` objects.  You can also get, set and
     89reset typecast functions on the connection level using the methods
     90:meth:`TypeCache.get_typecast`, :meth:`TypeCache.set_typecast` and
     91:meth:`TypeCache.reset_typecast` of the :attr:`Connection.type_cache`.  This
     92will not affect other connections or future connections. In order to be sure
     93a global change is picked up by a running connection, you must reopen it or
     94call :meth:`TypeCache.reset_typecast` on the :attr:`Connection.type_cache`.
    3895
    3996
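
To illustrate the global functions documented above, here is a rough sketch; the custom cast function is only an example, not part of PyGreSQL::

    import pgdb

    # the built-in default cast for a well-known type
    assert pgdb.get_typecast('int4') is int

    def cast_money_as_float(value):
        """Example cast: convert a money string like '$1,234.50' to float."""
        return float(''.join(c for c in value if c.isdigit() or c in '.-'))

    # install the custom cast globally for the money type
    pgdb.set_typecast('money', cast_money_as_float)

    # new connections will pick this up; connections that are already open
    # only after con.type_cache.reset_typecast() has been called on them

    # restore the default cast (reset_typecast() without arguments resets all)
    pgdb.reset_typecast('money')
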
  • trunk/docs/contents/pgdb/typecache.rst

    r786 r797  
    1 Cursor -- The cursor object
    2 ===========================
     1TypeCache -- The internal cache for database types
     2==================================================
    33
    44.. py:currentmodule:: pgdb
    55
    6 .. class:: Cursor
    7 
    8 These objects represent a database cursor, which is used to manage the context
    9 of a fetch operation. Cursors created from the same connection are not
    10 isolated, i.e., any changes done to the database by a cursor are immediately
    11 visible by the other cursors. Cursors created from different connections can
    12 or can not be isolated, depending on the level of transaction isolation.
    13 The default PostgreSQL transaction isolation level is "read committed".
    14 
    15 Cursor objects respond to the following methods and attributes.
    16 
    17 Note that ``Cursor`` objects also implement both the iterator and the
    18 context manager protocol, i.e. you can iterate over them and you can use them
    19 in a ``with`` statement.
    20 
    21 description -- details regarding the result columns
    22 ---------------------------------------------------
    23 
    24 .. attribute:: Cursor.description
    25 
    26     This read-only attribute is a sequence of 7-item named tuples.
    27 
    28     Each of these named tuples contains information describing
    29     one result column:
    30 
    31         - *name*
    32         - *type_code*
    33         - *display_size*
    34         - *internal_size*
    35         - *precision*
    36         - *scale*
    37         - *null_ok*
    38 
    39     The values for *precision* and *scale* are only set for numeric types.
    40     The values for *display_size* and *null_ok* are always ``None``.
    41 
    42     This attribute will be ``None`` for operations that do not return rows
    43     or if the cursor has not had an operation invoked via the
    44     :meth:`Cursor.execute` or :meth:`Cursor.executemany` method yet.
    45 
    46 .. versionchanged:: 5.0
    47     Before version 5.0, this attribute was an ordinary tuple.
    48 
    49 rowcount -- number of rows of the result
    50 ----------------------------------------
    51 
    52 .. attribute:: Cursor.rowcount
    53 
    54     This read-only attribute specifies the number of rows that the last
    55     :meth:`Cursor.execute` or :meth:`Cursor.executemany` call produced
    56     (for DQL statements like SELECT) or affected (for DML statements like
    57     UPDATE or INSERT). It is also set by the :meth:`Cursor.copy_from` and
    58 :meth:`Cursor.copy_to` methods. The attribute is -1 in case no such
    59     method call has been performed on the cursor or the rowcount of the
    60     last operation cannot be determined by the interface.
    61 
    62 close -- close the cursor
    63 -------------------------
    64 
    65 .. method:: Cursor.close()
    66 
    67     Close the cursor now (rather than whenever it is deleted)
    68 
    69     :rtype: None
    70 
    71 The cursor will be unusable from this point forward; an :exc:`Error`
    72 (or subclass) exception will be raised if any operation is attempted
    73 with the cursor.
    74 
    75 execute -- execute a database operation
    76 ---------------------------------------
    77 
    78 .. method:: Cursor.execute(operation, [parameters])
    79 
    80     Prepare and execute a database operation (query or command)
    81 
    82     :param str operation: the database operation
    83     :param parameters: a sequence or mapping of parameters
    84     :returns: the cursor, so you can chain commands
    85 
    86 Parameters may be provided as sequence or mapping and will be bound to
    87 variables in the operation. Variables are specified using Python extended
    88 format codes, e.g. ``" ... WHERE name=%(name)s"``.
    89 
    90 A reference to the operation will be retained by the cursor. If the same
    91 operation object is passed in again, then the cursor can optimize its behavior.
    92 This is most effective for algorithms where the same operation is used,
    93 but different parameters are bound to it (many times).
    94 
    95 The parameters may also be specified as list of tuples to e.g. insert multiple
    96 rows in a single operation, but this kind of usage is deprecated:
    97 :meth:`Cursor.executemany` should be used instead.
    98 
    99 Note that in case this method raises a :exc:`DatabaseError`, you can get
    100 information about the error condition that has occurred by introspecting
    101 its :attr:`DatabaseError.sqlstate` attribute, which will be the ``SQLSTATE``
    102 error code associated with the error.  Applications that need to know which
    103 error condition has occurred should usually test the error code, rather than
    104 looking at the textual error message.
    105 
    106 executemany -- execute many similar database operations
    107 -------------------------------------------------------
    108 
    109 .. method:: Cursor.executemany(operation, [seq_of_parameters])
    110 
    111     Prepare and execute many similar database operations (queries or commands)
    112 
    113     :param str operation: the database operation
    114     :param seq_of_parameters: a sequence or mapping of parameter tuples or mappings
    115     :returns: the cursor, so you can chain commands
    116 
    117 Prepare a database operation (query or command) and then execute it against
    118 all parameter tuples or mappings found in the sequence *seq_of_parameters*.
    119 
    120 Parameters are bound to the query using Python extended format codes,
    121 e.g. ``" ... WHERE name=%(name)s"``.
    122 
    123 callproc -- Call a stored procedure
    124 -----------------------------------
    125 
    126 .. method:: Cursor.callproc(self, procname, [parameters]):
    127 
    128     Call a stored database procedure with the given name
    129 
    130     :param str procname: the name of the database function
    131     :param parameters: a sequence of parameters (can be empty or omitted)
    132 
    133 This method calls a stored procedure (function) in the PostgreSQL database.
    134 
    135 The sequence of parameters must contain one entry for each input argument
    136 that the function expects. The result of the call is the same as this input
    137 sequence; replacement of output and input/output parameters in the return
    138 value is currently not supported.
    139 
    140 The function may also provide a result set as output. These can be requested
    141 through the standard fetch methods of the cursor.
     6.. class:: TypeCache
    1427
    1438.. versionadded:: 5.0
    1449
    145 fetchone -- fetch next row of the query result
    146 ----------------------------------------------
     10The internal :class:`TypeCache` of PyGreSQL is not part of the DB-API 2
     11standard, but is documented here in case you need full control and
     12understanding of the internal handling of database types.
    14713
    148 .. method:: Cursor.fetchone()
     14The TypeCache is essentially a dictionary mapping PostgreSQL internal
     15type names and type OIDs to DB-API 2 "type codes" (which are also returned
     16as the *type_code* field of the :attr:`Cursor.description` attribute).
    14917
    150     Fetch the next row of a query result set
     18These type codes are strings which are equal to the PostgreSQL internal
     19type name, but they are also carrying additional information about the
     20associated PostgreSQL type in the following attributes:
    15121
    152     :returns: the next row of the query result set
    153     :rtype: named tuple or None
     22        - *oid* -- the OID of the type
     23        - *len*  -- the internal size
     24        - *type*  -- ``'b'`` = base, ``'c'`` = composite, ...
     25        - *category*  -- ``'A'`` = Array, ``'B'`` = Boolean, ...
     26        - *delim*  -- delimiter to be used when parsing arrays
     27        - *relid*  -- the table OID for composite types
    15428
    155 Fetch the next row of a query result set, returning a single named tuple,
    156 or ``None`` when no more data is available. The field names of the named
    157 tuple are the same as the column names of the database query as long as
    158 they are valid Python identifiers.
     29For details, see the PostgreSQL documentation on `pg_type
     30<http://www.postgresql.org/docs/current/static/catalog-pg-type.html>`_.
    15931
    160 An :exc:`Error` (or subclass) exception is raised if the previous call to
    161 :meth:`Cursor.execute` or :meth:`Cursor.executemany` did not produce
    162 any result set or no call was issued yet.
     32In addition to the dictionary methods, the :class:`TypeCache` provides
     33the following methods:
    16334
    164 .. versionchanged:: 5.0
    165     Before version 5.0, this method returned ordinary tuples.
     35.. method:: TypeCache.get_fields(typ)
    16636
    167 fetchmany -- fetch next set of rows of the query result
    168 -------------------------------------------------------
     37    Get the names and types of the fields of composite types
    16938
    170 .. method:: Cursor.fetchmany([size=None], [keep=False])
     39    :param typ: PostgreSQL type name or OID of a composite type
     40    :type typ: str or int
     41    :returns: a list of pairs of field names and types
     42    :rtype: list
    17143
    172     Fetch the next set of rows of a query result
     44.. method:: TypeCache.get_typecast(typ)
    17345
    174     :param size: the number of rows to be fetched
    175     :type size: int or None
    176     :param keep: if set to true, will keep the passed arraysize
    177     :type keep: bool
    178     :returns: the next set of rows of the query result
    179     :rtype: list of named tuples
     46    Get the cast function for the given database type
    18047
    181 Fetch the next set of rows of a query result, returning a list of named
    182 tuples. An empty sequence is returned when no more rows are available.
    183 The field names of the named tuple are the same as the column names of
    184 the database query as long as they are valid Python identifiers.
     48    :param str typ: PostgreSQL type name or type code
     49    :returns: the typecast function for the specified type
     50    :rtype: function or None
    18551
    186 The number of rows to fetch per call is specified by the *size* parameter.
    187 If it is not given, the cursor's :attr:`arraysize` determines the number of
    188 rows to be fetched. If you set the *keep* parameter to True, this is kept as
    189 new :attr:`arraysize`.
     52.. method:: TypeCache.set_typecast(typ, cast)
    19053
    191 The method tries to fetch as many rows as indicated by the *size* parameter.
    192 If this is not possible due to the specified number of rows not being
    193 available, fewer rows may be returned.
     54    Set a typecast function for the given database type(s)
    19455
    195 An :exc:`Error` (or subclass) exception is raised if the previous call to
    196 :meth:`Cursor.execute` or :meth:`Cursor.executemany` did not produce
    197 any result set or no call was issued yet.
     56    :param typ: PostgreSQL type name or type code, or list of such
     57    :type typ: str or list
     58    :param cast: the typecast function to be set for the specified type(s)
      59    :type cast: callable
    19860
    199 Note there are performance considerations involved with the *size* parameter.
    200 For optimal performance, it is usually best to use the :attr:`arraysize`
    201 attribute. If the *size* parameter is used, then it is best for it to retain
    202 the same value from one :meth:`Cursor.fetchmany` call to the next.
     61.. method:: TypeCache.reset_typecast([typ])
    20362
    204 .. versionchanged:: 5.0
    205     Before version 5.0, this method returned ordinary tuples.
     63    Reset the typecasts for the specified (or all) type(s) to their defaults
    20664
    207 fetchall -- fetch all rows of the query result
    208 ----------------------------------------------
      65    :param typ: PostgreSQL type name or type code, or list of such,
     66        or None to reset all typecast functions
     67    :type typ: str, list or None
    20968
    210 .. method:: Cursor.fetchall()
     69.. method:: TypeCache.typecast(typ, value)
    21170
    212     Fetch all (remaining) rows of a query result
     71    Cast the given value according to the given database type
    21372
    214     :returns: the set of all rows of the query result
    215     :rtype: list of named tuples
     73    :param str typ: PostgreSQL type name or type code
      74    :returns: the cast value
    21675
    217 Fetch all (remaining) rows of a query result, returning them as list of
    218 named tuples. The field names of the named tuple are the same as the column
    219 names of the database query as long as they are valid Python identifiers.
    220 
    221 Note that the cursor's :attr:`arraysize` attribute can affect the performance
    222 of this operation.
    223 
    224 .. versionchanged:: 5.0
    225     Before version 5.0, this method returned ordinary tuples.
    226 
    227 arraysize - the number of rows to fetch at a time
    228 -------------------------------------------------
    229 
    230 .. attribute:: Cursor.arraysize
    231 
    232     The number of rows to fetch at a time
    233 
    234 This read/write attribute specifies the number of rows to fetch at a time with
    235 :meth:`Cursor.fetchmany`. It defaults to 1, meaning to fetch a single row
    236 at a time.
    237 
    238 Methods and attributes that are not part of the standard
    239 --------------------------------------------------------
    24076
    24177.. note::
    24278
    243    The following methods and attributes are not part of the DB-API 2 standard.
    244 
    245 .. method:: Cursor.copy_from(stream, table, [format], [sep], [null], [size], [columns])
    246 
    247     Copy data from an input stream to the specified table
    248 
    249     :param stream: the input stream
    250         (must be a file-like object, a string or an iterable returning strings)
    251     :param str table: the name of a database table
    252     :param str format: the format of the data in the input stream,
    253         can be ``'text'`` (the default), ``'csv'``, or ``'binary'``
    254     :param str sep: a single character separator
    255         (the default is ``'\t'`` for text and ``','`` for csv)
    256     :param str null: the textual representation of the ``NULL`` value,
    257         can also be an empty string (the default is ``'\\N'``)
    258     :param int size: the size of the buffer when reading file-like objects
    259     :param list column: an optional list of column names
    260     :returns: the cursor, so you can chain commands
    261 
    262     :raises TypeError: parameters with wrong types
    263     :raises ValueError: invalid parameters
    264     :raises IOError: error when executing the copy operation
    265 
    266 This method can be used to copy data from an input stream on the client side
    267 to a database table on the server side using the ``COPY FROM`` command.
    268 The input stream can be provided in form of a file-like object (which must
    269 have a ``read()`` method), a string, or an iterable returning one row or
    270 multiple rows of input data on each iteration.
    271 
    272 The format must be text, csv or binary. The sep option sets the column
    273 separator (delimiter) used in the non binary formats. The null option sets
    274 the textual representation of ``NULL`` in the input.
    275 
    276 The size option sets the size of the buffer used when reading data from
    277 file-like objects.
    278 
    279 The copy operation can be restricted to a subset of columns. If no columns are
    280 specified, all of them will be copied.
    281 
    282 .. versionadded:: 5.0
    283 
    284 .. method:: Cursor.copy_to(stream, table, [format], [sep], [null], [decode], [columns])
    285 
    286     Copy data from the specified table to an output stream
    287 
    288     :param stream: the output stream (must be a file-like object or ``None``)
    289     :param str table: the name of a database table or a ``SELECT`` query
    290     :param str format: the format of the data in the input stream,
    291         can be ``'text'`` (the default), ``'csv'``, or ``'binary'``
    292     :param str sep: a single character separator
    293         (the default is ``'\t'`` for text and ``','`` for csv)
    294     :param str null: the textual representation of the ``NULL`` value,
    295         can also be an empty string (the default is ``'\\N'``)
    296     :param bool decode: whether decoded strings shall be returned
    297         for non-binary formats (the default is True in Python 3)
    298     :param list column: an optional list of column names
    299     :returns: a generator if stream is set to ``None``, otherwise the cursor
    300 
    301     :raises TypeError: parameters with wrong types
    302     :raises ValueError: invalid parameters
    303     :raises IOError: error when executing the copy operation
    304 
    305 This method can be used to copy data from a database table on the server side
    306 to an output stream on the client side using the ``COPY TO`` command.
    307 
    308 The output stream can be provided in form of a file-like object (which must
    309 have a ``write()`` method). Alternatively, if ``None`` is passed as the
    310 output stream, the method will return a generator yielding one row of output
    311 data on each iteration.
    312 
    313 Output will be returned as byte strings unless you set decode to true.
    314 
    315 Note that you can also use a ``SELECT`` query instead of the table name.
    316 
    317 The format must be text, csv or binary. The sep option sets the column
    318 separator (delimiter) used in the non binary formats. The null option sets
    319 the textual representation of ``NULL`` in the output.
    320 
    321 The copy operation can be restricted to a subset of columns. If no columns are
    322 specified, all of them will be copied.
    323 
    324 .. versionadded:: 5.0
    325 
    326 .. method:: Cursor.row_factory(row)
    327 
    328     Process rows before they are returned
    329 
    330     :param list row: the currently processed row of the result set
    331     :returns: the transformed row that the fetch methods shall return
    332 
    333 This method is used for processing result rows before returning them through
    334 one of the fetch methods. By default, rows are returned as named tuples.
    335 You can overwrite this method with a custom row factory if you want to
    336 return the rows as different kinds of objects. This same row factory will then
    337 be used for all result sets. If you overwrite this method, the method
    338 :meth:`Cursor.build_row_factory` for creating row factories dynamically
    339 will be ignored.
    340 
    341 Note that named tuples are very efficient and can be easily converted to
    342 dicts (even OrderedDicts) by calling ``row._asdict()``. If you still want
    343 to return rows as dicts, you can create a custom cursor class like this::
    344 
    345     class DictCursor(pgdb.Cursor):
    346 
    347         def row_factory(self, row):
    348             return {key: value for key, value in zip(self.colnames, row)}
    349 
    350     cur = DictCursor(con)  # get one DictCursor instance or
    351     con.cursor_type = DictCursor  # always use DictCursor instances
    352 
    353 .. versionadded:: 4.0
    354 
    355 .. method:: Cursor.build_row_factory()
    356 
    357     Build a row factory based on the current description
    358 
    359     :returns: callable with the signature of :meth:`Cursor.row_factory`
    360 
    361 This method returns row factories for creating named tuples. It is called
    362 whenever a new result set is created, and :attr:`Cursor.row_factory` is
    363 then assigned the return value of this method. You can overwrite this method
    364 with a custom row factory builder if you want to use different row factories
    365 for different result sets. Otherwise, you can also simply overwrite the
    366 :meth:`Cursor.row_factory` method. This method will then be ignored.
    367 
    368 The default implementation that delivers rows as named tuples essentially
    369 looks like this::
    370 
    371     def build_row_factory(self):
    372         return namedtuple('Row', self.colnames, rename=True)._make
    373 
    374 .. versionadded:: 5.0
    375 
    376 .. attribute:: Cursor.colnames
    377 
    378     The list of columns names of the current result set
    379 
    380 The values in this list are the same values as the *name* elements
    381 in the :attr:`Cursor.description` attribute. Always use the latter
    382 if you want to remain standard compliant.
    383 
    384 .. versionadded:: 5.0
    385 
    386 .. attribute:: Cursor.coltypes
    387 
    388     The list of columns types of the current result set
    389 
    390 The values in this list are the same values as the *type_code* elements
    391 in the :attr:`Cursor.description` attribute. Always use the latter
    392 if you want to remain standard compliant.
    393 
    394 .. versionadded:: 5.0
     79    Note that the :class:`TypeCache` is always bound to a database connection.
     80    You can also get, set and reset typecast functions on a global level using
     81    the functions :func:`pgdb.get_typecast`, :func:`pgdb.set_typecast` and
     82    :func:`pgdb.reset_typecast`.  If you do this, the current database
     83    connections will continue to use their already cached typecast functions
     84    unless you call the :meth:`TypeCache.reset_typecast` method on the
     85    :attr:`Connection.type_cache` of the running connections.
  • trunk/docs/contents/pgdb/types.rst

    r796 r797  
    33
    44.. py:currentmodule:: pgdb
     5
     6.. _type_constructors:
    57
    68Type constructors
     
    7072    SQL ``NULL`` values are always represented by the Python *None* singleton
    7173    on input and output.
     74
     75.. _type_objects:
    7276
    7377Type objects
  • trunk/pgdb.py

    r796 r797  
    117117
    118118
    119 ### Internal Types Handling
     119### Internal Type Handling
    120120
    121121def decimal_type(decimal_type=None):
    122     """Get or set global type to be used for decimal values."""
     122    """Get or set global type to be used for decimal values.
     123
     124    Note that connections cache cast functions. To be sure a global change
     125    is picked up by a running connection, call con.type_cache.reset_typecast().
     126    """
    123127    global Decimal
    124128    if decimal_type is not None:
    125         _cast['numeric'] = Decimal = decimal_type
     129        Decimal = decimal_type
     130        set_typecast('numeric', decimal_type)
    126131    return Decimal
    127132
    128133
    129 def _cast_bool(value):
    130     return value[:1] in ('t', 'T')
    131 
    132 
    133 def _cast_money(value):
    134     return Decimal(''.join(filter(
    135         lambda v: v in '0123456789.-', value)))
    136 
    137 
    138 _cast = {'char': str, 'bpchar': str, 'name': str,
    139     'text': str, 'varchar': str,
    140     'bool': _cast_bool, 'bytea': unescape_bytea,
    141     'int2': int, 'int4': int, 'serial': int,
    142     'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
    143     'oid': long, 'oid8': long,
    144     'float4': float, 'float8': float,
    145     'numeric': Decimal, 'money': _cast_money,
    146     'record': cast_record}
    147 
    148 
    149 def _db_error(msg, cls=DatabaseError):
    150     """Return DatabaseError with empty sqlstate attribute."""
    151     error = cls(msg)
    152     error.sqlstate = None
    153     return error
    154 
    155 
    156 def _op_error(msg):
    157     """Return OperationalError."""
    158     return _db_error(msg, OperationalError)
     134def cast_bool(value):
     135    """Cast boolean value in database format to bool."""
     136    if value:
     137        return value[0] in ('t', 'T')
     138
     139
     140def cast_money(value):
     141    """Cast money value in database format to Decimal."""
     142    if value:
     143        value = value.replace('(', '-')
     144        return Decimal(''.join(c for c in value if c.isdigit() or c in '.-'))
     145
     146
     147class Typecasts(dict):
     148    """Dictionary mapping database types to typecast functions.
     149
      150    The cast functions get passed the string representation of a value
      151    coming from the database and must convert it to a Python object of
      152    the corresponding database type.  They never get passed None, since
      153    SQL NULL values are always converted to None before the cast
      154    functions are called.  However, they may get passed an empty string.
     155    """
     156
     157    # the default cast functions
     158    # (str functions are ignored but have been added for faster access)
     159    defaults = {'char': str, 'bpchar': str, 'name': str,
     160        'text': str, 'varchar': str,
     161        'bool': cast_bool, 'bytea': unescape_bytea,
     162        'int2': int, 'int4': int, 'serial': int,
     163        'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
     164        'oid': long, 'oid8': long,
     165        'float4': float, 'float8': float,
     166        'numeric': Decimal, 'money': cast_money,
     167        'anyarray': cast_array, 'record': cast_record}
     168
     169    def __missing__(self, typ):
     170        """Create a cast function if it is not cached.
     171
     172        Note that this class never raises a KeyError,
      173        but returns None when no special cast function exists.
     174        """
     175        cast = self.defaults.get(typ)
     176        if cast:
     177            # store default for faster access
     178            self[typ] = cast
     179        elif typ.startswith('_'):
     180            # create array cast
     181            base_cast = self[typ[1:]]
     182            cast = self.create_array_cast(base_cast)
     183            if base_cast:
     184                # store only if base type exists
     185                self[typ] = cast
     186        return cast
     187
     188    def get(self, typ, default=None):
     189        """Get the typecast function for the given database type."""
     190        return self[typ] or default
     191
     192    def set(self, typ, cast):
     193        """Set a typecast function for the specified database type(s)."""
     194        if isinstance(typ, basestring):
     195            typ = [typ]
     196        if cast is None:
     197            for t in typ:
     198                self.pop(t, None)
     199                self.pop('_%s' % t, None)
     200        else:
     201            if not callable(cast):
     202                raise TypeError("Cast parameter must be callable")
     203            for t in typ:
     204                self[t] = cast
      205                self.pop('_%s' % t, None)
     206
     207    def reset(self, typ=None):
     208        """Reset the typecasts for the specified type(s) to their defaults.
     209
     210        When no type is specified, all typecasts will be reset.
     211        """
     212        defaults = self.defaults
     213        if typ is None:
     214            self.clear()
     215            self.update(defaults)
     216        else:
     217            if isinstance(typ, basestring):
     218                typ = [typ]
     219            for t in typ:
     220                self.set(t, defaults.get(t))
     221
     222    def create_array_cast(self, cast):
     223        """Create an array typecast for the given base cast."""
     224        return lambda v: cast_array(v, cast)
     225
     226    def create_record_cast(self, name, fields, casts):
     227        """Create a named record typecast for the given fields and casts."""
     228        record = namedtuple(name, fields)
     229        return lambda v: record(*cast_record(v, casts))
     230
     231
     232_typecasts = Typecasts()  # this is the global typecast dictionary
     233
     234
     235def get_typecast(typ):
     236    """Get the global typecast function for the given database type(s)."""
     237    return _typecasts.get(typ)
     238
     239
     240def set_typecast(typ, cast):
     241    """Set a global typecast function for the given database type(s).
     242
     243    Note that connections cache cast functions. To be sure a global change
     244    is picked up by a running connection, call con.type_cache.reset_typecast().
     245    """
     246    _typecasts.set(typ, cast)
     247
     248
     249def reset_typecast(typ=None):
     250    """Reset the global typecasts for the given type(s) to their default.
     251
     252    When no type is specified, all typecasts will be reset.
     253
     254    Note that connections cache cast functions. To be sure a global change
     255    is picked up by a running connection, call con.type_cache.reset_typecast().
     256    """
     257    _typecasts.reset(typ)
     258
     259
     260class LocalTypecasts(Typecasts):
     261    """Map typecasts, including local composite types, to cast functions."""
     262
     263    defaults = _typecasts
     264
     265    def __missing__(self, typ):
     266        """Create a cast function if it is not cached."""
     267        if typ.startswith('_'):
     268            base_cast = self[typ[1:]]
     269            cast = self.create_array_cast(base_cast)
     270            if base_cast:
     271                self[typ] = cast
     272        else:
     273            cast = self.defaults.get(typ)
     274            if cast:
     275                self[typ] = cast
     276            else:
     277                fields = self.get_fields(typ)
     278                if fields:
     279                    casts = [self[field.type] for field in fields]
     280                    fields = [field.name for field in fields]
     281                    cast = self.create_record_cast(typ, fields, casts)
     282                    self[typ] = cast
     283        return cast
     284
     285    def get_fields(self, typ):
     286        """Return the fields for the given record type.
     287
     288        This method will be replaced with a method that looks up the fields
     289        using the type cache of the connection.
     290        """
     291        return []
    159292
    160293
     
    178311        return self
    179312
    180 ColumnInfo = namedtuple('ColumnInfo', ['name', 'type'])
     313FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
    181314
    182315
     
    193326        self._escape_string = cnx.escape_string
    194327        self._src = cnx.source()
     328        self._typecasts = LocalTypecasts()
     329        self._typecasts.get_fields = self.get_fields
    195330
    196331    def __missing__(self, key):
     
    225360            return default
    226361
    227     def columns(self, key):
    228         """Get the names and types of the columns of composite types."""
    229         try:
    230             typ = self[key]
    231         except KeyError:
    232             return None  # this type is not known
    233         if typ.type != 'c' or not typ.relid:
     362    def get_fields(self, typ):
     363        """Get the names and types of the fields of composite types."""
     364        if not isinstance(typ, TypeCode):
     365            typ = self.get(typ)
     366            if not typ:
     367                return None
     368        if not typ.relid:
    234369            return None  # this type is not composite
    235370        self._src.execute("SELECT attname, atttypid"
    236371            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
    237372            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
    238         return [ColumnInfo(name, int(oid))
     373        return [FieldInfo(name, self.get(int(oid)))
    239374            for name, oid in self._src.fetch(-1)]
    240375
     376    def get_typecast(self, typ):
     377        """Get the typecast function for the given database type."""
     378        return self._typecasts.get(typ)
     379
     380    def set_typecast(self, typ, cast):
     381        """Set a typecast function for the specified database type(s)."""
     382        self._typecasts.set(typ, cast)
     383
     384    def reset_typecast(self, typ=None):
     385        """Reset the typecast function for the specified database type(s)."""
     386        self._typecasts.reset(typ)
     387
    241388    def typecast(self, typ, value):
    242         """Cast value according to database type."""
     389        """Cast the given value according to the given database type."""
    243390        if value is None:
    244391            # for NULL values, no typecast is necessary
    245392            return None
    246         cast = _cast.get(typ)
    247         if cast is str:
    248             return value  # no typecast necessary
    249         if cast is None:
    250             if typ.startswith('_'):
    251                 # cast as an array type
    252                 cast = _cast.get(typ[1:])
    253                 return cast_array(value, cast)
    254             # check whether this is a composite type
    255             cols = self.columns(typ)
    256             if cols:
    257                 getcast = self.getcast
    258                 cast = [getcast(col.type) for col in cols]
    259                 value = cast_record(value, cast)
    260                 fields = [col.name for col in cols]
    261                 record = namedtuple(typ, fields)
    262                 return record(*value)
    263             return value  # no typecast available or necessary
    264         else:
    265             return cast(value)
    266 
    267     def getcast(self, key):
    268         """Get a cast function for the given database type."""
    269         if isinstance(key, int):
    270             try:
    271                 typ = self[key]
    272             except KeyError:
    273                 return None
    274         else:
    275             typ = key
    276         typecast = self.typecast
    277         return lambda value: typecast(typ, value)
     393        cast = self.get_typecast(typ)
     394        if not cast or cast is str:
     395            # no typecast is necessary
     396            return value
     397        return cast(value)
    278398
    279399
     
    286406    def __getitem__(self, key):
    287407        return self.quote(super(_quotedict, self).__getitem__(key))
     408
     409
     410### Error messages
     411
     412def _db_error(msg, cls=DatabaseError):
     413    """Return DatabaseError with empty sqlstate attribute."""
     414    error = cls(msg)
     415    error.sqlstate = None
     416    return error
     417
     418
     419def _op_error(msg):
     420    """Return OperationalError."""
     421    return _db_error(msg, OperationalError)
    288422
    289423
     
    322456        self.close()
    323457
    324     def _quote(self, val):
     458    def _quote(self, value):
    325459        """Quote value depending on its type."""
    326         if val is None:
     460        if value is None:
    327461            return 'NULL'
    328         if isinstance(val, (datetime, date, time, timedelta, Json)):
    329             val = str(val)
    330         if isinstance(val, basestring):
    331             if isinstance(val, Binary):
    332                 val = self._cnx.escape_bytea(val)
     462        if isinstance(value, (datetime, date, time, timedelta, Json)):
     463            value = str(value)
     464        if isinstance(value, basestring):
     465            if isinstance(value, Binary):
     466                value = self._cnx.escape_bytea(value)
    333467                if bytes is not str:  # Python >= 3.0
    334                     val = val.decode('ascii')
     468                    value = value.decode('ascii')
    335469            else:
    336                 val = self._cnx.escape_string(val)
    337             return "'%s'" % val
    338         if isinstance(val, float):
    339             if isinf(val):
    340                 return "'-Infinity'" if val < 0 else "'Infinity'"
    341             if isnan(val):
     470                value = self._cnx.escape_string(value)
     471            return "'%s'" % value
     472        if isinstance(value, float):
     473            if isinf(value):
     474                return "'-Infinity'" if value < 0 else "'Infinity'"
     475            if isnan(value):
    342476                return "'NaN'"
    343             return val
    344         if isinstance(val, (int, long, Decimal)):
    345             return val
    346         if isinstance(val, list):
     477            return value
     478        if isinstance(value, (int, long, Decimal)):
     479            return value
     480        if isinstance(value, list):
    347481            # Quote value as an ARRAY constructor. This is better than using
    348482            # an array literal because it carries the information that this is
    349483            # an array and not a string.  One issue with this syntax is that
    350             # you need to add an explicit type cast when passing empty arrays.
     484            # you need to add an explicit typecast when passing empty arrays.
    351485            # The ARRAY keyword is actually only necessary at the top level.
    352486            q = self._quote
    353             return 'ARRAY[%s]' % ','.join(str(q(v)) for v in val)
    354         if isinstance(val, tuple):
     487            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
     488        if isinstance(value, tuple):
    355489            # Quote as a ROW constructor.  This is better than using a record
    356490            # literal because it carries the information that this is a record
     
    359493            # when the record has a single column, which is not really useful.
    360494            q = self._quote
    361             return '(%s)' % ','.join(str(q(v)) for v in val)
     495            return '(%s)' % ','.join(str(q(v)) for v in value)
    362496        try:
    363             return val.__pg_repr__()
     497            value = value.__pg_repr__()
     498            if isinstance(value, (tuple, list)):
     499                value = self._quote(value)
     500            return value
    364501        except AttributeError:
    365502            raise InterfaceError(
    366                 'do not know how to handle type %s' % type(val))
    367 
     503                'Do not know how to adapt type %s' % type(value))
    368504
    369505    def _quoteparams(self, string, parameters):
     
    456592                    raise  # database provides error message
    457593                except Exception as err:
    458                     raise _op_error("can't start transaction")
     594                    raise _op_error("Can't start transaction")
    459595                self._dbcnx._tnx = True
    460596            for parameters in seq_of_parameters:
     
    471607        except Error as err:
    472608            raise _db_error(
    473                 "error in '%s': '%s' " % (sql, err), InterfaceError)
     609                "Error in '%s': '%s' " % (sql, err), InterfaceError)
    474610        except Exception as err:
    475             raise _op_error("internal error in '%s': %s" % (sql, err))
     611            raise _op_error("Internal error in '%s': %s" % (sql, err))
    476612        # then initialize result raw count and description
    477613        if self._src.resulttype == RESULT_DQL:
     
    561697        except AttributeError:
    562698            if size:
    563                 raise ValueError("size must only be set for file-like objects")
     699                raise ValueError("Size must only be set for file-like objects")
    564700            if binary_format:
    565701                input_type = bytes
     
    630766        if format is not None:
    631767            if not isinstance(format, basestring):
    632                 raise TypeError("format option must be a string")
      768                raise TypeError("The format option must be a string")
    633769            if format not in ('text', 'csv', 'binary'):
    634                 raise ValueError("invalid format")
     770                raise ValueError("Invalid format")
    635771            options.append('format %s' % (format,))
    636772        if sep is not None:
    637773            if not isinstance(sep, basestring):
    638                 raise TypeError("sep option must be a string")
     774                raise TypeError("The sep option must be a string")
    639775            if format == 'binary':
    640                 raise ValueError("sep is not allowed with binary format")
     776                raise ValueError(
     777                    "The sep option is not allowed with binary format")
    641778            if len(sep) != 1:
    642                 raise ValueError("sep must be a single one-byte character")
     779                raise ValueError(
     780                    "The sep option must be a single one-byte character")
    643781            options.append('delimiter %s')
    644782            params.append(sep)
    645783        if null is not None:
    646784            if not isinstance(null, basestring):
    647                 raise TypeError("null option must be a string")
     785                raise TypeError("The null option must be a string")
    648786            options.append('null %s')
    649787            params.append(null)
     
    697835                write = stream.write
    698836            except AttributeError:
    699                 raise TypeError("need an output stream to copy to")
     837                raise TypeError("Need an output stream to copy to")
    700838        if not table or not isinstance(table, basestring):
    701             raise TypeError("need a table to copy to")
     839            raise TypeError("Need a table to copy to")
    702840        if table.lower().startswith('select'):
    703841            if columns:
    704                 raise ValueError("columns must be specified in the query")
     842                raise ValueError("Columns must be specified in the query")
    705843            table = '(%s)' % (table,)
    706844        else:
     
    711849        if format is not None:
    712850            if not isinstance(format, basestring):
    713                 raise TypeError("format option must be a string")
     851                raise TypeError("The format option must be a string")
    714852            if format not in ('text', 'csv', 'binary'):
    715                 raise ValueError("invalid format")
     853                raise ValueError("Invalid format")
    716854            options.append('format %s' % (format,))
    717855        if sep is not None:
    718856            if not isinstance(sep, basestring):
    719                 raise TypeError("sep option must be a string")
     857                raise TypeError("The sep option must be a string")
    720858            if binary_format:
    721                 raise ValueError("sep is not allowed with binary format")
     859                raise ValueError(
     860                    "The sep option is not allowed with binary format")
    722861            if len(sep) != 1:
    723                 raise ValueError("sep must be a single one-byte character")
     862                raise ValueError(
     863                    "The sep option must be a single one-byte character")
    724864            options.append('delimiter %s')
    725865            params.append(sep)
    726866        if null is not None:
    727867            if not isinstance(null, basestring):
    728                 raise TypeError("null option must be a string")
     868                raise TypeError("The null option must be a string")
    729869            options.append('null %s')
    730870            params.append(null)
     
    736876        else:
    737877            if not isinstance(decode, (int, bool)):
    738                 raise TypeError("decode option must be a boolean")
     878                raise TypeError("The decode option must be a boolean")
    739879            if decode and binary_format:
    740                 raise ValueError("decode is not allowed with binary format")
     880                raise ValueError(
     881                    "The decode option is not allowed with binary format")
    741882        if columns:
    742883            if not isinstance(columns, basestring):
     
    788929    def nextset():
    789930        """Not supported."""
    790         raise NotSupportedError("nextset() is not supported")
     931        raise NotSupportedError("The nextset() method is not supported")
    791932
    792933    @staticmethod
     
    8731014            self._cnx.source()
    8741015        except Exception:
    875             raise _op_error("invalid connection")
     1016            raise _op_error("Invalid connection")
    8761017
    8771018    def __enter__(self):
     
    9031044            self._cnx = None
    9041045        else:
    905             raise _op_error("connection has been closed")
     1046            raise _op_error("Connection has been closed")
    9061047
    9071048    def commit(self):
     
    9151056                    raise
    9161057                except Exception:
    917                     raise _op_error("can't commit")
    918         else:
    919             raise _op_error("connection has been closed")
     1058                    raise _op_error("Can't commit")
     1059        else:
     1060            raise _op_error("Connection has been closed")
    9201061
    9211062    def rollback(self):
     
    9291070                    raise
    9301071                except Exception:
    931                     raise _op_error("can't rollback")
    932         else:
    933             raise _op_error("connection has been closed")
     1072                    raise _op_error("Can't rollback")
     1073        else:
     1074            raise _op_error("Connection has been closed")
    9341075
    9351076    def cursor(self):
     
    9391080                return self.cursor_type(self)
    9401081            except Exception:
    941                 raise _op_error("invalid connection")
    942         else:
    943             raise _op_error("connection has been closed")
     1082                raise _op_error("Invalid connection")
     1083        else:
     1084            raise _op_error("Connection has been closed")
    9441085
    9451086    if shortcutmethods:  # otherwise do not implement and document this
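
For a feel of how the new Typecasts dictionary behaves on its own, here is a sketch that assumes the class is accessible as pgdb.Typecasts and needs no database connection::

    import pgdb

    tc = pgdb.Typecasts()     # an empty cache backed by the default casts

    assert tc['int4'] is int                # defaults are cached lazily
    assert tc.get('no_such_type') is None   # unknown types give None, no KeyError

    # a leading underscore denotes an array type; its cast is derived
    # automatically from the cast of the base type
    cast_int_array = tc['_int4']
    print(cast_int_array('{1,2,3}'))        # a list of ints

    tc.set('int4', lambda v: int(v) * 2)    # install a custom cast ...
    assert tc['int4']('21') == 42
    tc.reset('int4')                        # ... and restore the default
    assert tc['int4'] is int
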
  • trunk/tests/test_dbapi20.py

    r796 r797  
    291291            self.assertIsNone(d.null_ok)
    292292
    293     def test_type_cache(self):
    294         con = self._connect()
    295         cur = con.cursor()
    296         type_cache = con.type_cache
    297         self.assertNotIn('numeric', type_cache)
    298         type_info = type_cache['numeric']
    299         self.assertIn('numeric', type_cache)
    300         self.assertEqual(type_info, 'numeric')
    301         self.assertEqual(type_info.oid, 1700)
    302         self.assertEqual(type_info.type, 'b')  # base
    303         self.assertEqual(type_info.category, 'N')  # numeric
    304         self.assertEqual(type_info.delim, ',')
    305         self.assertIs(con.type_cache[1700], type_info)
    306         self.assertNotIn('pg_type', type_cache)
    307         type_info = type_cache['pg_type']
    308         self.assertIn('numeric', type_cache)
    309         self.assertEqual(type_info.type, 'c')  # composite
    310         self.assertEqual(type_info.category, 'C')  # composite
    311         cols = type_cache.columns('pg_type')
    312         self.assertEqual(cols[0].name, 'typname')
    313         typname = type_cache[cols[0].type]
    314         self.assertEqual(typname, 'name')
    315         self.assertEqual(typname.type, 'b')  # base
    316         self.assertEqual(typname.category, 'S')  # string
    317         self.assertEqual(cols[3].name, 'typlen')
    318         typlen = type_cache[cols[3].type]
    319         self.assertEqual(typlen, 'int2')
    320         self.assertEqual(typlen.type, 'b')  # base
    321         self.assertEqual(typlen.category, 'N')  # numeric
    322         cur.close()
    323         cur = con.cursor()
    324         type_cache = con.type_cache
    325         self.assertIn('numeric', type_cache)
    326         cur.close()
    327         con.close()
    328         con = self._connect()
    329         cur = con.cursor()
    330         type_cache = con.type_cache
    331         self.assertNotIn('pg_type', type_cache)
    332         self.assertEqual(type_cache.get('pg_type'), type_info)
    333         self.assertIn('pg_type', type_cache)
    334         self.assertIsNone(type_cache.get(
    335             self.table_prefix + '_surely_does_not_exist'))
    336         cur.close()
    337         con.close()
     293    def test_type_cache_info(self):
     294        con = self._connect()
     295        try:
     296            cur = con.cursor()
     297            type_cache = con.type_cache
     298            self.assertNotIn('numeric', type_cache)
     299            type_info = type_cache['numeric']
     300            self.assertIn('numeric', type_cache)
     301            self.assertEqual(type_info, 'numeric')
     302            self.assertEqual(type_info.oid, 1700)
     303            self.assertEqual(type_info.len, -1)
     304            self.assertEqual(type_info.type, 'b')  # base
     305            self.assertEqual(type_info.category, 'N')  # numeric
     306            self.assertEqual(type_info.delim, ',')
     307            self.assertEqual(type_info.relid, 0)
     308            self.assertIs(con.type_cache[1700], type_info)
     309            self.assertNotIn('pg_type', type_cache)
     310            type_info = type_cache['pg_type']
     311            self.assertIn('numeric', type_cache)
     312            self.assertEqual(type_info.type, 'c')  # composite
     313            self.assertEqual(type_info.category, 'C')  # composite
     314            cols = type_cache.get_fields('pg_type')
     315            self.assertEqual(cols[0].name, 'typname')
     316            typname = type_cache[cols[0].type]
     317            self.assertEqual(typname, 'name')
     318            self.assertEqual(typname.type, 'b')  # base
     319            self.assertEqual(typname.category, 'S')  # string
     320            self.assertEqual(cols[3].name, 'typlen')
     321            typlen = type_cache[cols[3].type]
     322            self.assertEqual(typlen, 'int2')
     323            self.assertEqual(typlen.type, 'b')  # base
     324            self.assertEqual(typlen.category, 'N')  # numeric
     325            cur.close()
     326            cur = con.cursor()
     327            type_cache = con.type_cache
     328            self.assertIn('numeric', type_cache)
     329            cur.close()
     330        finally:
     331            con.close()
     332        con = self._connect()
     333        try:
     334            cur = con.cursor()
     335            type_cache = con.type_cache
     336            self.assertNotIn('pg_type', type_cache)
     337            self.assertEqual(type_cache.get('pg_type'), type_info)
     338            self.assertIn('pg_type', type_cache)
     339            self.assertIsNone(type_cache.get(
     340                self.table_prefix + '_surely_does_not_exist'))
     341            cur.close()
     342        finally:
     343            con.close()
     344
     345    def test_type_cache_typecast(self):
     346        con = self._connect()
     347        try:
     348            cur = con.cursor()
     349            type_cache = con.type_cache
     350            self.assertIs(type_cache.get_typecast('int4'), int)
     351            cast_int = lambda v: 'int(%s)' % v
     352            type_cache.set_typecast('int4', cast_int)
     353            query = 'select 2::int2, 4::int4, 8::int8'
     354            cur.execute(query)
     355            i2, i4, i8 = cur.fetchone()
     356            self.assertEqual(i2, 2)
     357            self.assertEqual(i4, 'int(4)')
     358            self.assertEqual(i8, 8)
     359            self.assertEqual(type_cache.typecast('int4', 42), 'int(42)')
     360            type_cache.set_typecast(['int2', 'int8'], cast_int)
     361            cur.execute(query)
     362            i2, i4, i8 = cur.fetchone()
     363            self.assertEqual(i2, 'int(2)')
     364            self.assertEqual(i4, 'int(4)')
     365            self.assertEqual(i8, 'int(8)')
     366            type_cache.reset_typecast('int4')
     367            cur.execute(query)
     368            i2, i4, i8 = cur.fetchone()
     369            self.assertEqual(i2, 'int(2)')
     370            self.assertEqual(i4, 4)
     371            self.assertEqual(i8, 'int(8)')
     372            type_cache.reset_typecast(['int2', 'int8'])
     373            cur.execute(query)
     374            i2, i4, i8 = cur.fetchone()
     375            self.assertEqual(i2, 2)
     376            self.assertEqual(i4, 4)
     377            self.assertEqual(i8, 8)
     378            type_cache.set_typecast(['int2', 'int8'], cast_int)
     379            cur.execute(query)
     380            i2, i4, i8 = cur.fetchone()
     381            self.assertEqual(i2, 'int(2)')
     382            self.assertEqual(i4, 4)
     383            self.assertEqual(i8, 'int(8)')
     384            type_cache.reset_typecast()
     385            cur.execute(query)
     386            i2, i4, i8 = cur.fetchone()
     387            self.assertEqual(i2, 2)
     388            self.assertEqual(i4, 4)
     389            self.assertEqual(i8, 8)
     390            cur.close()
     391        finally:
     392            con.close()
    338393
    339394    def test_cursor_iteration(self):
     
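A minimal sketch of the per-connection typecast API that the type cache tests above exercise; the connection parameters and the custom cast function are illustrative placeholders, not part of this changeset:

    import pgdb

    con = pgdb.connect(database='mydb')  # placeholder connection parameters
    cache = con.type_cache
    # install a custom cast for int4 values on this connection only
    cache.set_typecast('int4', lambda v: 'int(%s)' % v)
    cur = con.cursor()
    cur.execute('select 4::int4')
    print(cur.fetchone()[0])  # 'int(4)'
    cache.reset_typecast()    # restore the built-in typecast functions
    con.close()
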
    533588            self.assertEqual(type_code, pgdb.RECORD)
    534589            self.assertNotEqual(type_code, pgdb.ARRAY)
    535             columns = con.type_cache.columns(type_code)
     590            columns = con.type_cache.get_fields(type_code)
    536591            self.assertEqual(columns[0].name, 'name')
    537592            self.assertEqual(columns[1].name, 'age')
     
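The hunk above renames the columns() method of the type cache to get_fields(); it returns the column descriptors of a composite type. A short sketch, assuming 'con' is an open pgdb connection:

    # 'con' is assumed to be an open pgdb connection
    fields = con.type_cache.get_fields('pg_type')
    print(fields[0].name)                  # 'typname'
    print(con.type_cache[fields[0].type])  # type info of that column, e.g. 'name'
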
    598653        try:
    599654            cur = con.cursor()
    600             self.assertTrue(pgdb.decimal_type(int) is int)
    601             cur.execute('select 42')
    602             self.assertEqual(cur.description[0].type_code, pgdb.INTEGER)
     655            # change decimal type globally to int
     656            int_type = lambda v: int(float(v))
     657            self.assertTrue(pgdb.decimal_type(int_type) is int_type)
     658            cur.execute('select 4.25')
     659            self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
    603660            value = cur.fetchone()[0]
    604661            self.assertTrue(isinstance(value, int))
    605             self.assertEqual(value, 42)
     662            self.assertEqual(value, 4)
     663            # change decimal type again to float
    606664            self.assertTrue(pgdb.decimal_type(float) is float)
    607665            cur.execute('select 4.25')
    608666            self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
    609667            value = cur.fetchone()[0]
     668            # the connection still uses the old setting
     669            self.assertTrue(isinstance(value, int))
     670            # bust the cache for type functions for the connection
     671            con.type_cache.reset_typecast()
     672            cur.execute('select 4.25')
     673            self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
     674            value = cur.fetchone()[0]
     675            # now the connection uses the new setting
    610676            self.assertTrue(isinstance(value, float))
    611677            self.assertEqual(value, 4.25)
     
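As the changed test above illustrates, pgdb.decimal_type() switches the cast for NUMERIC values globally, but an open connection keeps using its cached typecast functions until its local cache is reset. A brief sketch of that interaction:

    import pgdb

    pgdb.decimal_type(float)         # return NUMERIC values as float from now on
    # an already open connection 'con' still uses its cached cast ...
    con.type_cache.reset_typecast()  # ... until its local cache is cleared
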
    614680            pgdb.decimal_type(decimal_type)
    615681        self.assertTrue(pgdb.decimal_type() is decimal_type)
     682
     683    def test_global_typecast(self):
     684        try:
     685            query = 'select 2::int2, 4::int4, 8::int8'
     686            self.assertIs(pgdb.get_typecast('int4'), int)
     687            cast_int = lambda v: 'int(%s)' % v
     688            pgdb.set_typecast('int4', cast_int)
     689            con = self._connect()
     690            try:
     691                i2, i4, i8 = con.cursor().execute(query).fetchone()
     692            finally:
     693                con.close()
     694            self.assertEqual(i2, 2)
     695            self.assertEqual(i4, 'int(4)')
     696            self.assertEqual(i8, 8)
     697            pgdb.set_typecast(['int2', 'int8'], cast_int)
     698            con = self._connect()
     699            try:
     700                i2, i4, i8 = con.cursor().execute(query).fetchone()
     701            finally:
     702                con.close()
     703            self.assertEqual(i2, 'int(2)')
     704            self.assertEqual(i4, 'int(4)')
     705            self.assertEqual(i8, 'int(8)')
     706            pgdb.reset_typecast('int4')
     707            con = self._connect()
     708            try:
     709                i2, i4, i8 = con.cursor().execute(query).fetchone()
     710            finally:
     711                con.close()
     712            self.assertEqual(i2, 'int(2)')
     713            self.assertEqual(i4, 4)
     714            self.assertEqual(i8, 'int(8)')
     715            pgdb.reset_typecast(['int2', 'int8'])
     716            con = self._connect()
     717            try:
     718                i2, i4, i8 = con.cursor().execute(query).fetchone()
     719            finally:
     720                con.close()
     721            self.assertEqual(i2, 2)
     722            self.assertEqual(i4, 4)
     723            self.assertEqual(i8, 8)
     724            pgdb.set_typecast(['int2', 'int8'], cast_int)
     725            con = self._connect()
     726            try:
     727                i2, i4, i8 = con.cursor().execute(query).fetchone()
     728            finally:
     729                con.close()
     730            self.assertEqual(i2, 'int(2)')
     731            self.assertEqual(i4, 4)
     732            self.assertEqual(i8, 'int(8)')
     733        finally:
     734            pgdb.reset_typecast()
     735        con = self._connect()
     736        try:
     737            i2, i4, i8 = con.cursor().execute(query).fetchone()
     738        finally:
     739            con.close()
     740        self.assertEqual(i2, 2)
     741        self.assertEqual(i4, 4)
     742        self.assertEqual(i8, 8)
    616743
    617744    def test_unicode_with_utf8(self):
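Finally, a sketch of the module-level typecast functions covered by test_global_typecast; the custom cast and the connection parameters are placeholders only:

    import pgdb

    pgdb.set_typecast('int4', lambda v: 'int(%s)' % v)  # affects newly opened connections
    con = pgdb.connect(database='mydb')                 # placeholder parameters
    print(con.cursor().execute('select 4::int4').fetchone()[0])  # 'int(4)'
    con.close()
    pgdb.reset_typecast()  # restore the default typecast functions
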