Changeset 798 for trunk/pg.py


Ignore:
Timestamp:
Jan 30, 2016, 2:55:18 PM (4 years ago)
Author:
cito
Message:

Port type cache and typecasting from pgdb to pg

So far, the typecasting in the classic module has only been done by
the C extension module and was not extensible through typecasting
functions in Python. This has now been made extensible by adding
a cast hook to the C extension module which has been hooked up to
a new type cache object that holds information on the types and the
associated typecast functions. All of this works very similarly to the
pgdb module now, except that the basic types are still handled by
the C extension module and the Python typecast functions are only
called via the hook for types which are not supported internally.

Also added tests and a chapter on the type cache in the documentation,
and cleaned up the error messages in the C extension module.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/pg.py

    r793 r798  
    4141from re import compile as regex
    4242from json import loads as jsondecode, dumps as jsonencode
     43
     44try:
     45    long
     46except NameError:  # Python >= 3.0
     47    long = int
    4348
    4449try:
     
    376381
    377382
    378 class _CastRecord:
    379     """Class providing methods for casting records and record elements.
    380 
    381     This is needed when getting result values from one of the higher level DB
    382     methods, since the lower level query method only casts the other types.
     383def cast_bool(value):
     384    """Cast a boolean value."""
     385    if not get_bool():
     386        return value
     387    return value[0] == 't'
     388
     389
     390def cast_json(value):
     391    """Cast a JSON value."""
     392    cast = get_jsondecode()
     393    if not cast:
     394        return value
     395    return cast(value)
     396
     397
     398def cast_num(value):
     399    """Cast a numeric value."""
     400    return (get_decimal() or float)(value)
     401
     402
     403def cast_money(value):
     404    """Cast a money value."""
     405    point = get_decimal_point()
     406    if not point:
     407        return value
     408    if point != '.':
     409        value = value.replace(point, '.')
     410    value = value.replace('(', '-')
     411    value = ''.join(c for c in value if c.isdigit() or c in '.-')
     412    return (get_decimal() or float)(value)
     413
     414
     415def cast_int2vector(value):
     416    """Cast an int2vector value."""
     417    return [int(v) for v in value.split()]
     418
     419
     420class Typecasts(dict):
     421    """Dictionary mapping database types to typecast functions.
     422
     423    The cast functions get passed the string representation of a value in
     424    the database which they need to convert to a Python object.  The
     425    passed string will never be None since NULL values are already be
     426    handled before the cast function is called.
     427
     428    Note that the basic types are already handled by the C extension.
     429    They only need to be handled here as record or array components.
    383430    """
    384431
    385     @staticmethod
    386     def cast_bool(v):
    387         if not get_bool():
    388             return v
    389         return v[0] == 't'
    390 
    391     @staticmethod
    392     def cast_bytea(v):
    393         return unescape_bytea(v)
    394 
    395     @staticmethod
    396     def cast_float(v):
    397         return float(v)
    398 
    399     @staticmethod
    400     def cast_int(v):
    401         return int(v)
    402 
    403     @staticmethod
    404     def cast_json(v):
    405         cast = get_jsondecode()
    406         if not cast:
    407             return v
    408         return cast(v)
    409 
    410     @staticmethod
    411     def cast_num(v):
    412         return (get_decimal() or float)(v)
    413 
    414     @staticmethod
    415     def cast_money(v):
    416         point = get_decimal_point()
    417         if not point:
    418             return v
    419         if point != '.':
    420             v = v.replace(point, '.')
    421         v = v.replace('(', '-')
    422         v = ''.join(c for c in v if c.isdigit() or c in '.-')
    423         return (get_decimal() or float)(v)
     432    # the default cast functions
     433    # (str functions are ignored but have been added for faster access)
     434    defaults = {'char': str, 'bpchar': str, 'name': str,
     435        'text': str, 'varchar': str,
     436        'bool': cast_bool, 'bytea': unescape_bytea,
     437        'int2': int, 'int4': int, 'serial': int,
     438        'int8': long, 'json': cast_json, 'jsonb': cast_json,
     439        'oid': long, 'oid8': long,
     440        'float4': float, 'float8': float,
     441        'numeric': cast_num, 'money': cast_money,
     442        'int2vector': cast_int2vector,
     443        'anyarray': cast_array, 'record': cast_record}
     444
     445    def __missing__(self, typ):
     446        """Create a cast function if it is not cached.
     447       
     448        Note that this class never raises a KeyError,
     449        but returns None when no special cast function exists.
     450        """
     451        if not isinstance(typ, str):
     452            raise TypeError('Invalid type: %s' % typ)
     453        cast = self.defaults.get(typ)
     454        if cast:
     455            # store default for faster access
     456            self[typ] = cast
     457        elif typ.startswith('_'):
     458            base_cast = self[typ[1:]]
     459            cast = self.create_array_cast(base_cast)
     460            if base_cast:
     461                self[typ] = cast
     462        else:
     463            attnames = self.get_attnames(typ)
     464            if attnames:
     465                casts = [self[v.pgtype] for v in attnames.values()]
     466                cast = self.create_record_cast(typ, attnames, casts)
     467                self[typ] = cast
     468        return cast
     469
     470    def get(self, typ, default=None):
     471        """Get the typecast function for the given database type."""
     472        return self[typ] or default
     473
     474    def set(self, typ, cast):
     475        """Set a typecast function for the specified database type(s)."""
     476        if isinstance(typ, basestring):
     477            typ = [typ]
     478        if cast is None:
     479            for t in typ:
     480                self.pop(t, None)
     481                self.pop('_%s' % t, None)
     482        else:
     483            if not callable(cast):
     484                raise TypeError("Cast parameter must be callable")
     485            for t in typ:
     486                self[t] = cast
     487                self.pop('_%s % t', None)
     488
     489    def reset(self, typ=None):
     490        """Reset the typecasts for the specified type(s) to their defaults.
     491
     492        When no type is specified, all typecasts will be reset.
     493        """
     494        defaults = self.defaults
     495        if typ is None:
     496            self.clear()
     497            self.update(defaults)
     498        else:
     499            if isinstance(typ, basestring):
     500                typ = [typ]
     501            for t in typ:
     502                self.set(t, defaults.get(t))
    424503
    425504    @classmethod
    426     def cast(cls, v, typ):
    427         types = typ.attnames.values()
    428         cast = [getattr(cls, 'cast_%s' % t.simple, None) for t in types]
    429         v = cast_record(v, cast)
    430         return typ.namedtuple(*v)
    431 
    432 
    433 class _PgType(str):
    434     """Class augmenting the simple type name with additional info."""
    435 
    436     _num_types = frozenset('int float num money'
    437         ' int2 int4 int8 float4 float8 numeric money'.split())
     505    def get_default(cls, typ):
     506        """Get the default typecast function for the given database type."""
     507        return cls.defaults.get(typ)
    438508
    439509    @classmethod
    440     def create(cls, db, pgtype, regtype, typrelid):
    441         """Create a PostgreSQL type name with additional info."""
    442         simple = 'record' if typrelid else _simpletype[pgtype]
    443         self = cls(regtype if db._regtypes else simple)
    444         self.db = db
    445         self.simple = simple
    446         self.pgtype = pgtype
    447         self.regtype = regtype
    448         self.typrelid = typrelid
    449         self._attnames = self._namedtuple = None
    450         return self
     510    def set_default(cls, typ, cast):
     511        """Set a default typecast function for the given database type(s)."""
     512        if isinstance(typ, basestring):
     513            typ = [typ]
     514        defaults = cls.defaults
     515        if cast is None:
     516            for t in typ:
     517                defaults.pop(t, None)
     518                defaults.pop('_%s' % t, None)
     519        else:
     520            if not callable(cast):
     521                raise TypeError("Cast parameter must be callable")
     522            for t in typ:
     523                defaults[t] = cast
     524                defaults.pop('_%s % t', None)
     525
     526    def get_attnames(self, typ):
     527        """Return the fields for the given record type.
     528
     529        This method will be replaced with the get_attnames() method of DbTypes.
     530        """
     531        return {}
     532
     533    def create_array_cast(self, cast):
     534        """Create an array typecast for the given base cast."""
     535        return lambda v: cast_array(v, cast)
     536
     537    def create_record_cast(self, name, fields, casts):
     538        """Create a named record typecast for the given fields and casts."""
     539        record = namedtuple(name, fields)
     540        return lambda v: record(*cast_record(v, casts))
     541
     542
     543def get_typecast(typ):
     544    """Get the global typecast function for the given database type(s)."""
     545    return Typecasts.get_default(typ)
     546
     547
     548def set_typecast(typ, cast):
     549    """Set a global typecast function for the given database type(s).
     550
     551    Note that connections cache cast functions. To be sure a global change
     552    is picked up by a running connection, call db.db_types.reset_typecast().
     553    """
     554    Typecasts.set_default(typ, cast)
     555
     556
     557class DbType(str):
     558    """Class augmenting the simple type name with additional info.
     559
     560    The following additional information is provided:
     561
     562        oid: the PostgreSQL type OID
     563        pgtype: the PostgreSQL type name
     564        regtype: the regular type name
     565        simple: the simple PyGreSQL type name
     566        typtype: b = base type, c = composite type etc.
     567        category: A = Array, b =Boolean, C = Composite etc.
     568        delim: delimiter for array types
     569        relid: corresponding table for composite types
     570        attnames: attributes for composite types
     571    """
    451572
    452573    @property
    453574    def attnames(self):
    454575        """Get names and types of the fields of a composite type."""
    455         if not self.typrelid:
     576        return self._get_attnames(self)
     577
     578
     579class DbTypes(dict):
     580    """Cache for PostgreSQL data types.
     581
     582    This cache maps type OIDs and names to DbType objects containing
     583    information on the associated database type.
     584    """
     585
     586    _num_types = frozenset('int float num money'
     587        ' int2 int4 int8 float4 float8 numeric money'.split())
     588
     589    def __init__(self, db):
     590        """Initialize type cache for connection."""
     591        super(DbTypes, self).__init__()
     592        self._get_attnames = db.get_attnames
     593        db = db.db
     594        self.query = db.query
     595        self.escape_string = db.escape_string
     596        self._typecasts = Typecasts()
     597        self._typecasts.get_attnames = self.get_attnames
     598        self._regtypes = False
     599
     600    def add(self, oid, pgtype, regtype,
     601               typtype, category, delim, relid):
     602        """Create a PostgreSQL type name with additional info."""
     603        if oid in self:
     604            return self[oid]
     605        simple = 'record' if relid else _simpletype[pgtype]
     606        typ = DbType(regtype if self._regtypes else simple)
     607        typ.oid = oid
     608        typ.simple = simple
     609        typ.pgtype = pgtype
     610        typ.regtype = regtype
     611        typ.typtype = typtype
     612        typ.category = category
     613        typ.delim = delim
     614        typ.relid = relid
     615        typ._get_attnames = self.get_attnames
     616        return typ
     617
     618    def __missing__(self, key):
     619        """Get the type info from the database if it is not cached."""
     620        try:
     621            res = self.query("SELECT oid, typname, typname::regtype,"
     622                " typtype, typcategory, typdelim, typrelid"
     623                " FROM pg_type WHERE oid=%s::regtype" %
     624                (DB._adapt_qualified_param(key, 1),), (key,)).getresult()
     625        except ProgrammingError:
     626            res = None
     627        if not res:
     628            raise KeyError('Type %s could not be found' % key)
     629        res = res[0]
     630        typ = self.add(*res)
     631        self[typ.oid] = self[typ.pgtype] = typ
     632        return typ
     633
     634    def get(self, key, default=None):
     635        """Get the type even if it is not cached."""
     636        try:
     637            return self[key]
     638        except KeyError:
     639            return default
     640
     641    def get_attnames(self, typ):
     642        """Get names and types of the fields of a composite type."""
     643        if not isinstance(typ, DbType):
     644            typ = self.get(typ)
     645            if not typ:
     646                return None
     647        if not typ.relid:
    456648            return None
    457         if not self._attnames:
    458             self._attnames = self.db.get_attnames(self.typrelid)
    459         return self._attnames
    460 
    461     @property
    462     def namedtuple(self):
    463         """Return named tuple class representing a composite type."""
    464         if not self._namedtuple:
    465             self._namedtuple = namedtuple(self, self.attnames)
    466         return self._namedtuple
    467 
    468     def cast(self, value):
    469         if value is not None and self.typrelid:
    470             value = _CastRecord.cast(value, self)
    471         return value
     649        return self._get_attnames(typ.relid, with_oid=False)
     650
     651    def get_typecast(self, typ):
     652        """Get the typecast function for the given database type."""
     653        return self._typecasts.get(typ)
     654
     655    def set_typecast(self, typ, cast):
     656        """Set a typecast function for the specified database type(s)."""
     657        self._typecasts.set(typ, cast)
     658
     659    def reset_typecast(self, typ=None):
     660        """Reset the typecast function for the specified database type(s)."""
     661        self._typecasts.reset(typ)
     662
     663    def typecast(self, value, typ):
     664        """Cast the given value according to the given database type."""
     665        if value is None:
     666            # for NULL values, no typecast is necessary
     667            return None
     668        if not isinstance(typ, DbType):
     669            typ = self.get(typ)
     670            if typ:
     671                typ = typ.pgtype
     672        cast = self.get_typecast(typ) if typ else None
     673        if not cast or cast is str:
     674            # no typecast is necessary
     675            return value
     676        return cast(value)
    472677
    473678
     
    691896        self._privileges = {}
    692897        self._args = args, kw
     898        self.dbtypes = DbTypes(self)
     899        db.set_cast_hook(self.dbtypes.typecast)
    693900        self.debug = None  # For debugging scripts, this can be set
    694901            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
     
    10481255            # not the order as defined by the columns in the table
    10491256            if len(pkey) > 1:
    1050                 indkey = [int(k) for k in pkey[0][2].split()]
     1257                indkey = pkey[0][2]
    10511258                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
    10521259                pkey = tuple(row[0] for row in pkey)
     
    10841291        return self.get_relations('r')
    10851292
    1086     def get_attnames(self, table, flush=False):
     1293    def get_attnames(self, table, with_oid=True, flush=False):
    10871294        """Given the name of a table, dig out the set of attribute names.
    10881295
     
    11051312            names = attnames[table]
    11061313        except KeyError:  # cache miss, check the database
    1107             q = ("SELECT a.attname, t.typname, t.typname::regtype, t.typrelid"
     1314            q = "a.attnum > 0"
     1315            if with_oid:
     1316                q = "(%s OR a.attname = 'oid')" % q
     1317            q = ("SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
     1318                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
    11081319                " FROM pg_attribute a"
    11091320                " JOIN pg_type t ON t.oid = a.atttypid"
    1110                 " WHERE a.attrelid = %s::regclass"
    1111                 " AND (a.attnum > 0 OR a.attname = 'oid')"
     1321                " WHERE a.attrelid = %s::regclass AND %s"
    11121322                " AND NOT a.attisdropped ORDER BY a.attnum") % (
    1113                     self._adapt_qualified_param(table, 1))
     1323                    self._adapt_qualified_param(table, 1), q)
    11141324            names = self.db.query(q, (table,)).getresult()
    1115             names = ((name, _PgType.create(self, pgtype, regtype, typrelid))
    1116                 for name, pgtype, regtype, typrelid in names)
     1325            types = self.dbtypes
     1326            names = ((name[0], types.add(*name[1:])) for name in names)
    11171327            names = AttrDict(names)
    11181328            attnames[table] = names  # cache it
     
    11221332        """Use regular type names instead of simplified type names."""
    11231333        if regtypes is None:
    1124             return self._regtypes
     1334            return self.dbtypes._regtypes
    11251335        else:
    11261336            regtypes = bool(regtypes)
    1127             if regtypes != self._regtypes:
    1128                 self._regtypes = regtypes
     1337            if regtypes != self.dbtypes._regtypes:
     1338                self.dbtypes._regtypes = regtypes
    11291339                self._attnames.clear()
     1340                self.dbtypes.clear()
    11301341            return regtypes
    11311342
     
    12151426            if qoid and n == 'oid':
    12161427                n = qoid
    1217             else:
    1218                 value = attnames[n].cast(value)
    12191428            row[n] = value
    12201429        return row
     
    12631472                if qoid and n == 'oid':
    12641473                    n = qoid
    1265                 else:
    1266                     value = attnames[n].cast(value)
    12671474                row[n] = value
    12681475        return row
     
    13321539                if qoid and n == 'oid':
    13331540                    n = qoid
    1334                 else:
    1335                     value = attnames[n].cast(value)
    13361541                row[n] = value
    13371542        return row
     
    14351640                if qoid and n == 'oid':
    14361641                    n = qoid
    1437                 else:
    1438                     value = attnames[n].cast(value)
    14391642                row[n] = value
    14401643        else:
     
    14581661                continue
    14591662            t = t.simple
    1460             if t in _PgType._num_types:
     1663            if t in DbTypes._num_types:
    14611664                row[n] = 0
    14621665            elif t == 'bool':
Note: See TracChangeset for help on using the changeset viewer.