Changeset 797 for trunk/pgdb.py


Ignore:
Timestamp:
Jan 29, 2016, 5:43:22 PM (4 years ago)
Author:
cito
Message:

Cache typecast functions and make them configurable

The typecast functions used by the pgdb module are now cached
using a local and a global Typecasts class. The local cache is
bound to the connection and knows how to cast composite types.

Also added functions that allow registering custom typecast
functions on the global and local level.

Also added a chapter on type adaptation and casting to the docs.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/pgdb.py

    r796 r797  
    117117
    118118
    119 ### Internal Types Handling
     119### Internal Type Handling
    120120
    121121def decimal_type(decimal_type=None):
    122     """Get or set global type to be used for decimal values."""
     122    """Get or set global type to be used for decimal values.
     123
     124    Note that connections cache cast functions. To be sure a global change
     125    is picked up by a running connection, call con.type_cache.reset_typecast().
     126    """
    123127    global Decimal
    124128    if decimal_type is not None:
    125         _cast['numeric'] = Decimal = decimal_type
     129        Decimal = decimal_type
     130        set_typecast('numeric', decimal_type)
    126131    return Decimal
    127132
    128133
    129 def _cast_bool(value):
    130     return value[:1] in ('t', 'T')
    131 
    132 
    133 def _cast_money(value):
    134     return Decimal(''.join(filter(
    135         lambda v: v in '0123456789.-', value)))
    136 
    137 
    138 _cast = {'char': str, 'bpchar': str, 'name': str,
    139     'text': str, 'varchar': str,
    140     'bool': _cast_bool, 'bytea': unescape_bytea,
    141     'int2': int, 'int4': int, 'serial': int,
    142     'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
    143     'oid': long, 'oid8': long,
    144     'float4': float, 'float8': float,
    145     'numeric': Decimal, 'money': _cast_money,
    146     'record': cast_record}
    147 
    148 
    149 def _db_error(msg, cls=DatabaseError):
    150     """Return DatabaseError with empty sqlstate attribute."""
    151     error = cls(msg)
    152     error.sqlstate = None
    153     return error
    154 
    155 
    156 def _op_error(msg):
    157     """Return OperationalError."""
    158     return _db_error(msg, OperationalError)
     134def cast_bool(value):
     135    """Cast boolean value in database format to bool."""
     136    if value:
     137        return value[0] in ('t', 'T')
     138
     139
     140def cast_money(value):
     141    """Cast money value in database format to Decimal."""
     142    if value:
     143        value = value.replace('(', '-')
     144        return Decimal(''.join(c for c in value if c.isdigit() or c in '.-'))
     145
     146
     147class Typecasts(dict):
     148    """Dictionary mapping database types to typecast functions.
     149
     150    The cast functions get passed the string representation of a value in
     151    the database and must convert it to the corresponding Python object.
     152    NULL values are handled before a cast function is called (typecast()
     153    returns None for them), so the casts never get passed None as argument.
     154    However, they may get passed an empty string or a numeric null value.
     155    """
     156
     157    # the default cast functions
     158    # (str functions are ignored but have been added for faster access)
     159    defaults = {'char': str, 'bpchar': str, 'name': str,
     160        'text': str, 'varchar': str,
     161        'bool': cast_bool, 'bytea': unescape_bytea,
     162        'int2': int, 'int4': int, 'serial': int,
     163        'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
     164        'oid': long, 'oid8': long,
     165        'float4': float, 'float8': float,
     166        'numeric': Decimal, 'money': cast_money,
     167        'anyarray': cast_array, 'record': cast_record}
     168
     169    def __missing__(self, typ):
     170        """Create a cast function if it is not cached.
     171
     172        Note that this class never raises a KeyError,
     173        but returns None when no special cast function exists.
     174        """
     175        cast = self.defaults.get(typ)
     176        if cast:
     177            # store default for faster access
     178            self[typ] = cast
     179        elif typ.startswith('_'):
     180            # create array cast
     181            base_cast = self[typ[1:]]
     182            cast = self.create_array_cast(base_cast)
     183            if base_cast:
     184                # store only if base type exists
     185                self[typ] = cast
     186        return cast
     187
     188    def get(self, typ, default=None):
     189        """Get the typecast function for the given database type."""
     190        return self[typ] or default
     191
     192    def set(self, typ, cast):
     193        """Set a typecast function for the specified database type(s)."""
     194        if isinstance(typ, basestring):
     195            typ = [typ]
     196        if cast is None:
     197            for t in typ:
     198                self.pop(t, None)
     199                self.pop('_%s' % t, None)
     200        else:
     201            if not callable(cast):
     202                raise TypeError("Cast parameter must be callable")
     203            for t in typ:
     204                self[t] = cast
     205                self.pop('_%s' % t, None)
     206
     207    def reset(self, typ=None):
     208        """Reset the typecasts for the specified type(s) to their defaults.
     209
     210        When no type is specified, all typecasts will be reset.
     211        """
     212        defaults = self.defaults
     213        if typ is None:
     214            self.clear()
     215            self.update(defaults)
     216        else:
     217            if isinstance(typ, basestring):
     218                typ = [typ]
     219            for t in typ:
     220                self.set(t, defaults.get(t))
     221
     222    def create_array_cast(self, cast):
     223        """Create an array typecast for the given base cast."""
     224        return lambda v: cast_array(v, cast)
     225
     226    def create_record_cast(self, name, fields, casts):
     227        """Create a named record typecast for the given fields and casts."""
     228        record = namedtuple(name, fields)
     229        return lambda v: record(*cast_record(v, casts))
     230
     231
     232_typecasts = Typecasts()  # this is the global typecast dictionary
     233
     234
     235def get_typecast(typ):
     236    """Get the global typecast function for the given database type(s)."""
     237    return _typecasts.get(typ)
     238
     239
     240def set_typecast(typ, cast):
     241    """Set a global typecast function for the given database type(s).
     242
     243    Note that connections cache cast functions. To be sure a global change
     244    is picked up by a running connection, call con.type_cache.reset_typecast().
     245    """
     246    _typecasts.set(typ, cast)
     247
     248
     249def reset_typecast(typ=None):
     250    """Reset the global typecasts for the given type(s) to their default.
     251
     252    When no type is specified, all typecasts will be reset.
     253
     254    Note that connections cache cast functions. To be sure a global change
     255    is picked up by a running connection, call con.type_cache.reset_typecast().
     256    """
     257    _typecasts.reset(typ)
     258
     259
     260class LocalTypecasts(Typecasts):
     261    """Map typecasts, including local composite types, to cast functions."""
     262
     263    defaults = _typecasts
     264
     265    def __missing__(self, typ):
     266        """Create a cast function if it is not cached."""
     267        if typ.startswith('_'):
     268            base_cast = self[typ[1:]]
     269            cast = self.create_array_cast(base_cast)
     270            if base_cast:
     271                self[typ] = cast
     272        else:
     273            cast = self.defaults.get(typ)
     274            if cast:
     275                self[typ] = cast
     276            else:
     277                fields = self.get_fields(typ)
     278                if fields:
     279                    casts = [self[field.type] for field in fields]
     280                    fields = [field.name for field in fields]
     281                    cast = self.create_record_cast(typ, fields, casts)
     282                    self[typ] = cast
     283        return cast
     284
     285    def get_fields(self, typ):
     286        """Return the fields for the given record type.
     287
     288        This method will be replaced with a method that looks up the fields
     289        using the type cache of the connection.
     290        """
     291        return []
    159292
    160293
     
    178311        return self
    179312
    180 ColumnInfo = namedtuple('ColumnInfo', ['name', 'type'])
     313FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
    181314
    182315
     
    193326        self._escape_string = cnx.escape_string
    194327        self._src = cnx.source()
     328        self._typecasts = LocalTypecasts()
     329        self._typecasts.get_fields = self.get_fields
    195330
    196331    def __missing__(self, key):
     
    225360            return default
    226361
    227     def columns(self, key):
    228         """Get the names and types of the columns of composite types."""
    229         try:
    230             typ = self[key]
    231         except KeyError:
    232             return None  # this type is not known
    233         if typ.type != 'c' or not typ.relid:
     362    def get_fields(self, typ):
     363        """Get the names and types of the fields of composite types."""
     364        if not isinstance(typ, TypeCode):
     365            typ = self.get(typ)
     366            if not typ:
     367                return None
     368        if not typ.relid:
    234369            return None  # this type is not composite
    235370        self._src.execute("SELECT attname, atttypid"
    236371            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
    237372            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
    238         return [ColumnInfo(name, int(oid))
     373        return [FieldInfo(name, self.get(int(oid)))
    239374            for name, oid in self._src.fetch(-1)]
    240375
     376    def get_typecast(self, typ):
     377        """Get the typecast function for the given database type."""
     378        return self._typecasts.get(typ)
     379
     380    def set_typecast(self, typ, cast):
     381        """Set a typecast function for the specified database type(s)."""
     382        self._typecasts.set(typ, cast)
     383
     384    def reset_typecast(self, typ=None):
     385        """Reset the typecast function for the specified database type(s)."""
     386        self._typecasts.reset(typ)
     387
    241388    def typecast(self, typ, value):
    242         """Cast value according to database type."""
     389        """Cast the given value according to the given database type."""
    243390        if value is None:
    244391            # for NULL values, no typecast is necessary
    245392            return None
    246         cast = _cast.get(typ)
    247         if cast is str:
    248             return value  # no typecast necessary
    249         if cast is None:
    250             if typ.startswith('_'):
    251                 # cast as an array type
    252                 cast = _cast.get(typ[1:])
    253                 return cast_array(value, cast)
    254             # check whether this is a composite type
    255             cols = self.columns(typ)
    256             if cols:
    257                 getcast = self.getcast
    258                 cast = [getcast(col.type) for col in cols]
    259                 value = cast_record(value, cast)
    260                 fields = [col.name for col in cols]
    261                 record = namedtuple(typ, fields)
    262                 return record(*value)
    263             return value  # no typecast available or necessary
    264         else:
    265             return cast(value)
    266 
    267     def getcast(self, key):
    268         """Get a cast function for the given database type."""
    269         if isinstance(key, int):
    270             try:
    271                 typ = self[key]
    272             except KeyError:
    273                 return None
    274         else:
    275             typ = key
    276         typecast = self.typecast
    277         return lambda value: typecast(typ, value)
     393        cast = self.get_typecast(typ)
     394        if not cast or cast is str:
     395            # no typecast is necessary
     396            return value
     397        return cast(value)
    278398
    279399
     
    286406    def __getitem__(self, key):
    287407        return self.quote(super(_quotedict, self).__getitem__(key))
     408
     409
     410### Error messages
     411
     412def _db_error(msg, cls=DatabaseError):
     413    """Return DatabaseError with empty sqlstate attribute."""
     414    error = cls(msg)
     415    error.sqlstate = None
     416    return error
     417
     418
     419def _op_error(msg):
     420    """Return OperationalError."""
     421    return _db_error(msg, OperationalError)
    288422
    289423
     
    322456        self.close()
    323457
    324     def _quote(self, val):
     458    def _quote(self, value):
    325459        """Quote value depending on its type."""
    326         if val is None:
     460        if value is None:
    327461            return 'NULL'
    328         if isinstance(val, (datetime, date, time, timedelta, Json)):
    329             val = str(val)
    330         if isinstance(val, basestring):
    331             if isinstance(val, Binary):
    332                 val = self._cnx.escape_bytea(val)
     462        if isinstance(value, (datetime, date, time, timedelta, Json)):
     463            value = str(value)
     464        if isinstance(value, basestring):
     465            if isinstance(value, Binary):
     466                value = self._cnx.escape_bytea(value)
    333467                if bytes is not str:  # Python >= 3.0
    334                     val = val.decode('ascii')
     468                    value = value.decode('ascii')
    335469            else:
    336                 val = self._cnx.escape_string(val)
    337             return "'%s'" % val
    338         if isinstance(val, float):
    339             if isinf(val):
    340                 return "'-Infinity'" if val < 0 else "'Infinity'"
    341             if isnan(val):
     470                value = self._cnx.escape_string(value)
     471            return "'%s'" % value
     472        if isinstance(value, float):
     473            if isinf(value):
     474                return "'-Infinity'" if value < 0 else "'Infinity'"
     475            if isnan(value):
    342476                return "'NaN'"
    343             return val
    344         if isinstance(val, (int, long, Decimal)):
    345             return val
    346         if isinstance(val, list):
     477            return value
     478        if isinstance(value, (int, long, Decimal)):
     479            return value
     480        if isinstance(value, list):
    347481            # Quote value as an ARRAY constructor. This is better than using
    348482            # an array literal because it carries the information that this is
    349483            # an array and not a string.  One issue with this syntax is that
    350             # you need to add an explicit type cast when passing empty arrays.
     484            # you need to add an explicit typecast when passing empty arrays.
    351485            # The ARRAY keyword is actually only necessary at the top level.
    352486            q = self._quote
    353             return 'ARRAY[%s]' % ','.join(str(q(v)) for v in val)
    354         if isinstance(val, tuple):
     487            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
     488        if isinstance(value, tuple):
    355489            # Quote as a ROW constructor.  This is better than using a record
    356490            # literal because it carries the information that this is a record
     
    359493            # when the record has a single column, which is not really useful.
    360494            q = self._quote
    361             return '(%s)' % ','.join(str(q(v)) for v in val)
     495            return '(%s)' % ','.join(str(q(v)) for v in value)
    362496        try:
    363             return val.__pg_repr__()
     497            value = value.__pg_repr__()
     498            if isinstance(value, (tuple, list)):
     499                value = self._quote(value)
     500            return value
    364501        except AttributeError:
    365502            raise InterfaceError(
    366                 'do not know how to handle type %s' % type(val))
    367 
     503                'Do not know how to adapt type %s' % type(value))
    368504
    369505    def _quoteparams(self, string, parameters):
     
    456592                    raise  # database provides error message
    457593                except Exception as err:
    458                     raise _op_error("can't start transaction")
     594                    raise _op_error("Can't start transaction")
    459595                self._dbcnx._tnx = True
    460596            for parameters in seq_of_parameters:
     
    471607        except Error as err:
    472608            raise _db_error(
    473                 "error in '%s': '%s' " % (sql, err), InterfaceError)
     609                "Error in '%s': '%s' " % (sql, err), InterfaceError)
    474610        except Exception as err:
    475             raise _op_error("internal error in '%s': %s" % (sql, err))
     611            raise _op_error("Internal error in '%s': %s" % (sql, err))
    476612        # then initialize result raw count and description
    477613        if self._src.resulttype == RESULT_DQL:
     
    561697        except AttributeError:
    562698            if size:
    563                 raise ValueError("size must only be set for file-like objects")
     699                raise ValueError("Size must only be set for file-like objects")
    564700            if binary_format:
    565701                input_type = bytes
     
    630766        if format is not None:
    631767            if not isinstance(format, basestring):
    632                 raise TypeError("format option must be be a string")
     768                raise TypeError("The format option must be a string")
    633769            if format not in ('text', 'csv', 'binary'):
    634                 raise ValueError("invalid format")
     770                raise ValueError("Invalid format")
    635771            options.append('format %s' % (format,))
    636772        if sep is not None:
    637773            if not isinstance(sep, basestring):
    638                 raise TypeError("sep option must be a string")
     774                raise TypeError("The sep option must be a string")
    639775            if format == 'binary':
    640                 raise ValueError("sep is not allowed with binary format")
     776                raise ValueError(
     777                    "The sep option is not allowed with binary format")
    641778            if len(sep) != 1:
    642                 raise ValueError("sep must be a single one-byte character")
     779                raise ValueError(
     780                    "The sep option must be a single one-byte character")
    643781            options.append('delimiter %s')
    644782            params.append(sep)
    645783        if null is not None:
    646784            if not isinstance(null, basestring):
    647                 raise TypeError("null option must be a string")
     785                raise TypeError("The null option must be a string")
    648786            options.append('null %s')
    649787            params.append(null)
     
    697835                write = stream.write
    698836            except AttributeError:
    699                 raise TypeError("need an output stream to copy to")
     837                raise TypeError("Need an output stream to copy to")
    700838        if not table or not isinstance(table, basestring):
    701             raise TypeError("need a table to copy to")
     839            raise TypeError("Need a table to copy to")
    702840        if table.lower().startswith('select'):
    703841            if columns:
    704                 raise ValueError("columns must be specified in the query")
     842                raise ValueError("Columns must be specified in the query")
    705843            table = '(%s)' % (table,)
    706844        else:
     
    711849        if format is not None:
    712850            if not isinstance(format, basestring):
    713                 raise TypeError("format option must be a string")
     851                raise TypeError("The format option must be a string")
    714852            if format not in ('text', 'csv', 'binary'):
    715                 raise ValueError("invalid format")
     853                raise ValueError("Invalid format")
    716854            options.append('format %s' % (format,))
    717855        if sep is not None:
    718856            if not isinstance(sep, basestring):
    719                 raise TypeError("sep option must be a string")
     857                raise TypeError("The sep option must be a string")
    720858            if binary_format:
    721                 raise ValueError("sep is not allowed with binary format")
     859                raise ValueError(
     860                    "The sep option is not allowed with binary format")
    722861            if len(sep) != 1:
    723                 raise ValueError("sep must be a single one-byte character")
     862                raise ValueError(
     863                    "The sep option must be a single one-byte character")
    724864            options.append('delimiter %s')
    725865            params.append(sep)
    726866        if null is not None:
    727867            if not isinstance(null, basestring):
    728                 raise TypeError("null option must be a string")
     868                raise TypeError("The null option must be a string")
    729869            options.append('null %s')
    730870            params.append(null)
     
    736876        else:
    737877            if not isinstance(decode, (int, bool)):
    738                 raise TypeError("decode option must be a boolean")
     878                raise TypeError("The decode option must be a boolean")
    739879            if decode and binary_format:
    740                 raise ValueError("decode is not allowed with binary format")
     880                raise ValueError(
     881                    "The decode option is not allowed with binary format")
    741882        if columns:
    742883            if not isinstance(columns, basestring):
     
    788929    def nextset():
    789930        """Not supported."""
    790         raise NotSupportedError("nextset() is not supported")
     931        raise NotSupportedError("The nextset() method is not supported")
    791932
    792933    @staticmethod
     
    8731014            self._cnx.source()
    8741015        except Exception:
    875             raise _op_error("invalid connection")
     1016            raise _op_error("Invalid connection")
    8761017
    8771018    def __enter__(self):
     
    9031044            self._cnx = None
    9041045        else:
    905             raise _op_error("connection has been closed")
     1046            raise _op_error("Connection has been closed")
    9061047
    9071048    def commit(self):
     
    9151056                    raise
    9161057                except Exception:
    917                     raise _op_error("can't commit")
    918         else:
    919             raise _op_error("connection has been closed")
     1058                    raise _op_error("Can't commit")
     1059        else:
     1060            raise _op_error("Connection has been closed")
    9201061
    9211062    def rollback(self):
     
    9291070                    raise
    9301071                except Exception:
    931                     raise _op_error("can't rollback")
    932         else:
    933             raise _op_error("connection has been closed")
     1072                    raise _op_error("Can't rollback")
     1073        else:
     1074            raise _op_error("Connection has been closed")
    9341075
    9351076    def cursor(self):
     
    9391080                return self.cursor_type(self)
    9401081            except Exception:
    941                 raise _op_error("invalid connection")
    942         else:
    943             raise _op_error("connection has been closed")
     1082                raise _op_error("Invalid connection")
     1083        else:
     1084            raise _op_error("Connection has been closed")
    9441085
    9451086    if shortcutmethods:  # otherwise do not implement and document this
Note: See TracChangeset for help on using the changeset viewer.