Changeset 793 for trunk/pg.py


Ignore:
Timestamp:
Jan 28, 2016, 3:07:41 PM (4 years ago)
Author:
cito
Message:

Improve quoting and typecasting in the pg module

Larger refactoring of the code for adapting and typecasting in the pg module.
Things are now a lot cleaner and clearer.

The _Adapt class is responsible for all adapting of Python objects to their
PostgreSQL equivalents when sending data to the database. The typecasting
from PostgreSQL on output happens in the C module, except for the typecasting
of records which is new and provided by the _CastRecord class.

The classic module also did not work properly when regular type names were
switched on with use_regtypes(True), since the adapting of types relied on
the PyGreSQL type names. This has been solved by adding a new _PgType class
that is essentially the old type name, but augmented with all the necessary
information necessary to adapt types, particularly record types.

All tests in test_classic_dbwrapper now run twice, using opposite settings
for the various configuration settings like use_bool() or use_regtypes(),
in order to make sure that no internal functions rely on default settings.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/pg.py

    r785 r793  
    165165
    166166_simpletype = _SimpleType()
     167
     168
     169class _Adapt:
     170    """Mixin providing methods for adapting records and record elements.
     171
     172    This is used when passing values from one of the higher level DB
     173    methods as parameters for a query.
     174
     175    This class must be mixed in to a connection class, because it needs
     176    connection specific methods such as escape_bytea().
     177    """
     178
     179    _bool_true_values = frozenset('t true 1 y yes on'.split())
     180
     181    _date_literals = frozenset('current_date current_time'
     182        ' current_timestamp localtime localtimestamp'.split())
     183
     184    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
     185    _re_record_quote = regex(r'[(,"\\]')
     186    _re_array_escape = _re_record_escape = regex(r'(["\\])')
     187
     188    @classmethod
     189    def _adapt_bool(cls, v):
     190        """Adapt a boolean parameter."""
     191        if isinstance(v, basestring):
     192            if not v:
     193                return None
     194            v = v.lower() in cls._bool_true_values
     195        return 't' if v else 'f'
     196
     197    @classmethod
     198    def _adapt_date(cls, v):
     199        """Adapt a date parameter."""
     200        if not v:
     201            return None
     202        if isinstance(v, basestring) and v.lower() in cls._date_literals:
     203            return _Literal(v)
     204        return v
     205
     206    @staticmethod
     207    def _adapt_num(v):
     208        """Adapt a numeric parameter."""
     209        if not v and v != 0:
     210            return None
     211        return v
     212
     213    _adapt_int = _adapt_float = _adapt_money = _adapt_num
     214
     215    def _adapt_bytea(self, v):
     216        """Adapt a bytea parameter."""
     217        return self.escape_bytea(v)
     218
     219    def _adapt_json(self, v):
     220        """Adapt a json parameter."""
     221        if not v:
     222            return None
     223        if isinstance(v, basestring):
     224            return v
     225        return self.encode_json(v)
     226
     227    @classmethod
     228    def _adapt_text_array(cls, v):
     229        """Adapt a text type array parameter."""
     230        if isinstance(v, list):
     231            adapt = cls._adapt_text_array
     232            return '{%s}' % ','.join(adapt(v) for v in v)
     233        if v is None:
     234            return 'null'
     235        if not v:
     236            return '""'
     237        v = str(v)
     238        if cls._re_array_quote.search(v):
     239            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
     240        return v
     241
     242    _adapt_date_array = _adapt_text_array
     243
     244    @classmethod
     245    def _adapt_bool_array(cls, v):
     246        """Adapt a boolean array parameter."""
     247        if isinstance(v, list):
     248            adapt = cls._adapt_bool_array
     249            return '{%s}' % ','.join(adapt(v) for v in v)
     250        if v is None:
     251            return 'null'
     252        if isinstance(v, basestring):
     253            if not v:
     254                return 'null'
     255            v = v.lower() in cls._bool_true_values
     256        return 't' if v else 'f'
     257
     258    @classmethod
     259    def _adapt_num_array(cls, v):
     260        """Adapt a numeric array parameter."""
     261        if isinstance(v, list):
     262            adapt = cls._adapt_num_array
     263            return '{%s}' % ','.join(adapt(v) for v in v)
     264        if not v and v != 0:
     265            return 'null'
     266        return str(v)
     267
     268    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
     269            _adapt_num_array
     270
     271    def _adapt_bytea_array(self, v):
     272        """Adapt a bytea array parameter."""
     273        if isinstance(v, list):
     274            return b'{' + b','.join(
     275                self._adapt_bytea_array(v) for v in v) + b'}'
     276        if v is None:
     277            return b'null'
     278        return self.escape_bytea(v).replace(b'\\', b'\\\\')
     279
     280    def _adapt_json_array(self, v):
     281        """Adapt a json array parameter."""
     282        if isinstance(v, list):
     283            adapt = self._adapt_json_array
     284            return '{%s}' % ','.join(adapt(v) for v in v)
     285        if not v:
     286            return 'null'
     287        if not isinstance(v, basestring):
     288            v = self.encode_json(v)
     289        if self._re_array_quote.search(v):
     290            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
     291        return v
     292
     293    def _adapt_record(self, v, typ):
     294        """Adapt a record parameter with given type."""
     295        typ = typ.attnames.values()
     296        if len(typ) != len(v):
     297            raise TypeError('Record parameter %s has wrong size' % v)
     298        return '(%s)' % ','.join(getattr(self,
     299            '_adapt_record_%s' % t.simple)(v) for v, t in zip(v, typ))
     300
     301    @classmethod
     302    def _adapt_record_text(cls, v):
     303        """Adapt a text type record component."""
     304        if v is None:
     305            return ''
     306        if not v:
     307            return '""'
     308        v = str(v)
     309        if cls._re_record_quote.search(v):
     310            v = '"%s"' % cls._re_record_escape.sub(r'\\\1', v)
     311        return v
     312
     313    _adapt_record_date = _adapt_record_text
     314
     315    @classmethod
     316    def _adapt_record_bool(cls, v):
     317        """Adapt a boolean record component."""
     318        if v is None:
     319            return ''
     320        if isinstance(v, basestring):
     321            if not v:
     322                return ''
     323            v = v.lower() in cls._bool_true_values
     324        return 't' if v else 'f'
     325
     326    @staticmethod
     327    def _adapt_record_num(v):
     328        """Adapt a numeric record component."""
     329        if not v and v != 0:
     330            return ''
     331        return str(v)
     332
     333    _adapt_record_int = _adapt_record_float = _adapt_record_money = \
     334        _adapt_record_num
     335
     336    def _adapt_record_bytea(self, v):
     337        if v is None:
     338            return ''
     339        v = self.escape_bytea(v)
     340        if bytes is not str and isinstance(v, bytes):
     341            v = v.decode('ascii')
     342        return v.replace('\\', '\\\\')
     343
     344    def _adapt_record_json(self, v):
     345        """Adapt a bytea record component."""
     346        if not v:
     347            return ''
     348        if not isinstance(v, basestring):
     349            v = self.encode_json(v)
     350        if self._re_array_quote.search(v):
     351            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
     352        return v
     353
     354    def _adapt_param(self, value, typ, params):
     355        """Adapt and add a parameter to the list."""
     356        if isinstance(value, _Literal):
     357            return value
     358        if value is not None:
     359            simple = typ.simple
     360            if simple == 'text':
     361                pass
     362            elif simple == 'record':
     363                if isinstance(value, tuple):
     364                    value = self._adapt_record(value, typ)
     365            elif simple.endswith('[]'):
     366                if isinstance(value, list):
     367                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
     368                    value = adapt(value)
     369            else:
     370                adapt = getattr(self, '_adapt_%s' % simple)
     371                value = adapt(value)
     372                if isinstance(value, _Literal):
     373                    return value
     374        params.append(value)
     375        return '$%d' % len(params)
     376
     377
     378class _CastRecord:
     379    """Class providing methods for casting records and record elements.
     380
     381    This is needed when getting result values from one of the higher level DB
     382    methods, since the lower level query method only casts the other types.
     383    """
     384
     385    @staticmethod
     386    def cast_bool(v):
     387        if not get_bool():
     388            return v
     389        return v[0] == 't'
     390
     391    @staticmethod
     392    def cast_bytea(v):
     393        return unescape_bytea(v)
     394
     395    @staticmethod
     396    def cast_float(v):
     397        return float(v)
     398
     399    @staticmethod
     400    def cast_int(v):
     401        return int(v)
     402
     403    @staticmethod
     404    def cast_json(v):
     405        cast = get_jsondecode()
     406        if not cast:
     407            return v
     408        return cast(v)
     409
     410    @staticmethod
     411    def cast_num(v):
     412        return (get_decimal() or float)(v)
     413
     414    @staticmethod
     415    def cast_money(v):
     416        point = get_decimal_point()
     417        if not point:
     418            return v
     419        if point != '.':
     420            v = v.replace(point, '.')
     421        v = v.replace('(', '-')
     422        v = ''.join(c for c in v if c.isdigit() or c in '.-')
     423        return (get_decimal() or float)(v)
     424
     425    @classmethod
     426    def cast(cls, v, typ):
     427        types = typ.attnames.values()
     428        cast = [getattr(cls, 'cast_%s' % t.simple, None) for t in types]
     429        v = cast_record(v, cast)
     430        return typ.namedtuple(*v)
     431
     432
     433class _PgType(str):
     434    """Class augmenting the simple type name with additional info."""
     435
     436    _num_types = frozenset('int float num money'
     437        ' int2 int4 int8 float4 float8 numeric money'.split())
     438
     439    @classmethod
     440    def create(cls, db, pgtype, regtype, typrelid):
     441        """Create a PostgreSQL type name with additional info."""
     442        simple = 'record' if typrelid else _simpletype[pgtype]
     443        self = cls(regtype if db._regtypes else simple)
     444        self.db = db
     445        self.simple = simple
     446        self.pgtype = pgtype
     447        self.regtype = regtype
     448        self.typrelid = typrelid
     449        self._attnames = self._namedtuple = None
     450        return self
     451
     452    @property
     453    def attnames(self):
     454        """Get names and types of the fields of a composite type."""
     455        if not self.typrelid:
     456            return None
     457        if not self._attnames:
     458            self._attnames = self.db.get_attnames(self.typrelid)
     459        return self._attnames
     460
     461    @property
     462    def namedtuple(self):
     463        """Return named tuple class representing a composite type."""
     464        if not self._namedtuple:
     465            self._namedtuple = namedtuple(self, self.attnames)
     466        return self._namedtuple
     467
     468    def cast(self, value):
     469        if value is not None and self.typrelid:
     470            value = _CastRecord.cast(value, self)
     471        return value
    167472
    168473
     
    350655# The actual PostGreSQL database connection interface:
    351656
    352 class DB(object):
     657class DB(_Adapt):
    353658    """Wrapper class for the _pg connection type."""
    354659
     
    452757        return bool(d) if get_bool() else ('t' if d else 'f')
    453758
    454     _bool_true_values = frozenset('t true 1 y yes on'.split())
    455 
    456     def _prepare_bool(self, d):
    457         """Prepare a boolean parameter."""
    458         if isinstance(d, basestring):
    459             if not d:
    460                 return None
    461             d = d.lower() in self._bool_true_values
    462         return 't' if d else 'f'
    463 
    464     _date_literals = frozenset('current_date current_time'
    465         ' current_timestamp localtime localtimestamp'.split())
    466 
    467     def _prepare_date(self, d):
    468         """Prepare a date parameter."""
    469         if not d:
    470             return None
    471         if isinstance(d, basestring) and d.lower() in self._date_literals:
    472             return _Literal(d)
    473         return d
    474 
    475     _num_types = frozenset('int float num money'
    476         ' int2 int4 int8 float4 float8 numeric money'.split())
    477 
    478     @staticmethod
    479     def _prepare_num(d):
    480         """Prepare a numeric parameter."""
    481         if not d and d != 0:
    482             return None
    483         return d
    484 
    485     _prepare_int = _prepare_float = _prepare_money = _prepare_num
    486 
    487     def _prepare_bytea(self, d):
    488         """Prepare a bytea parameter."""
    489         return self.escape_bytea(d)
    490 
    491     def _prepare_json(self, d):
    492         """Prepare a json parameter."""
    493         if not d:
    494             return None
    495         if isinstance(d, basestring):
    496             return d
    497         return self.encode_json(d)
    498 
    499     _re_array_escape = regex(r'(["\\])')
    500     _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    501 
    502     def _prepare_bool_array(self, d):
    503         """Prepare a bool array parameter."""
    504         if isinstance(d, list):
    505             return '{%s}' % ','.join(self._prepare_bool_array(v) for v in d)
    506         if d is None:
    507             return 'null'
    508         if isinstance(d, basestring):
    509             if not d:
    510                 return 'null'
    511             d = d.lower() in self._bool_true_values
    512         return 't' if d else 'f'
    513 
    514     def _prepare_num_array(self, d):
    515         """Prepare a numeric array parameter."""
    516         if isinstance(d, list):
    517             return '{%s}' % ','.join(self._prepare_num_array(v) for v in d)
    518         if not d and d != 0:
    519             return 'null'
    520         return str(d)
    521 
    522     _prepare_int_array = _prepare_float_array = _prepare_money_array = \
    523             _prepare_num_array
    524 
    525     def _prepare_text_array(self, d):
    526         """Prepare a text array parameter."""
    527         if isinstance(d, list):
    528             return '{%s}' % ','.join(self._prepare_text_array(v) for v in d)
    529         if d is None:
    530             return 'null'
    531         if not d:
    532             return '""'
    533         d = str(d)
    534         if self._re_array_quote.search(d):
    535             d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
    536         return d
    537 
    538     def _prepare_bytea_array(self, d):
    539         """Prepare a bytea array parameter."""
    540         if isinstance(d, list):
    541             return b'{' + b','.join(
    542                 self._prepare_bytea_array(v) for v in d) + b'}'
    543         if d is None:
    544             return b'null'
    545         return self.escape_bytea(d).replace(b'\\', b'\\\\')
    546 
    547     def _prepare_json_array(self, d):
    548         """Prepare a json array parameter."""
    549         if isinstance(d, list):
    550             return '{%s}' % ','.join(self._prepare_json_array(v) for v in d)
    551         if not d:
    552             return 'null'
    553         if not isinstance(d, basestring):
    554             d = self.encode_json(d)
    555         if self._re_array_quote.search(d):
    556             d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
    557         return d
    558 
    559     def _prepare_param(self, value, typ, params):
    560         """Prepare and add a parameter to the list."""
    561         if isinstance(value, _Literal):
    562             return value
    563         if value is not None and typ != 'text':
    564             if typ.endswith('[]'):
    565                 if isinstance(value, list):
    566                     prepare = getattr(self, '_prepare_%s_array' % typ[:-2])
    567                     value = prepare(value)
    568                 elif isinstance(value, basestring):
    569                     value = value.strip()
    570                     if not value.startswith('{') or not value.endswith('}'):
    571                         if value[:5].lower() == 'array':
    572                             value = value[5:].lstrip()
    573                         if value.startswith('[') and value.endswith(']'):
    574                             value = _Literal('ARRAY%s' % value)
    575                         else:
    576                             raise ValueError(
    577                                 'Invalid array expression: %s' % value)
    578                 else:
    579                     raise ValueError('Invalid array parameter: %s' % value)
    580             else:
    581                 prepare = getattr(self, '_prepare_%s' % typ)
    582                 value = prepare(value)
    583             if isinstance(value, _Literal):
    584                 return value
    585         params.append(value)
    586         return '$%d' % len(params)
    587 
    588759    def _list_params(self, params):
    589760        """Create a human readable parameter list."""
     
    591762
    592763    @staticmethod
    593     def _prepare_qualified_param(name, param):
     764    def _adapt_qualified_param(name, param):
    594765        """Quote parameter representing a qualified name.
    595766
     
    602773        if isinstance(param, int):
    603774            param = "$%d" % param
    604         if '.' not in name:
     775        if isinstance(name, basestring) and '.' not in name:
    605776            param = 'quote_ident(%s)' % (param,)
    606777        return param
     
    8701041                " WHERE i.indrelid=%s::regclass"
    8711042                " AND i.indisprimary ORDER BY a.attnum") % (
    872                     self._prepare_qualified_param(table, 1),)
     1043                    self._adapt_qualified_param(table, 1),)
    8731044            pkey = self.db.query(q, (table,)).getresult()
    8741045            if not pkey:
     
    9341105            names = attnames[table]
    9351106        except KeyError:  # cache miss, check the database
    936             q = ("SELECT a.attname, t.typname%s"
     1107            q = ("SELECT a.attname, t.typname, t.typname::regtype, t.typrelid"
    9371108                " FROM pg_attribute a"
    9381109                " JOIN pg_type t ON t.oid = a.atttypid"
     
    9401111                " AND (a.attnum > 0 OR a.attname = 'oid')"
    9411112                " AND NOT a.attisdropped ORDER BY a.attnum") % (
    942                     '::regtype' if self._regtypes else '',
    943                     self._prepare_qualified_param(table, 1))
     1113                    self._adapt_qualified_param(table, 1))
    9441114            names = self.db.query(q, (table,)).getresult()
    945             if not self._regtypes:
    946                 names = ((name, _simpletype[typ]) for name, typ in names)
     1115            names = ((name, _PgType.create(self, pgtype, regtype, typrelid))
     1116                for name, pgtype, regtype, typrelid in names)
    9471117            names = AttrDict(names)
    9481118            attnames[table] = names  # cache it
     
    9671137        except KeyError:  # cache miss, ask the database
    9681138            q = "SELECT has_table_privilege(%s, $2)" % (
    969                 self._prepare_qualified_param(table, 1),)
     1139                self._adapt_qualified_param(table, 1),)
    9701140            q = self.db.query(q, (table, privilege))
    9711141            ret = q.getresult()[0][0] == self._make_bool(True)
     
    10251195            row = dict(zip(keyname, row))
    10261196        params = []
    1027         param = partial(self._prepare_param, params=params)
     1197        param = partial(self._adapt_param, params=params)
    10281198        col = self.escape_identifier
    10291199        what = 'oid, *' if qoid else '*'
     
    10451215            if qoid and n == 'oid':
    10461216                n = qoid
     1217            else:
     1218                value = attnames[n].cast(value)
    10471219            row[n] = value
    10481220        return row
     
    10711243        qoid = _oid_key(table) if 'oid' in attnames else None
    10721244        params = []
    1073         param = partial(self._prepare_param, params=params)
     1245        param = partial(self._adapt_param, params=params)
    10741246        col = self.escape_identifier
    10751247        names, values = [], []
     
    10911263                if qoid and n == 'oid':
    10921264                    n = qoid
     1265                else:
     1266                    value = attnames[n].cast(value)
    10931267                row[n] = value
    10941268        return row
     
    11321306                    raise KeyError('Missing primary key in row')
    11331307        params = []
    1134         param = partial(self._prepare_param, params=params)
     1308        param = partial(self._adapt_param, params=params)
    11351309        col = self.escape_identifier
    11361310        where = ' AND '.join('%s = %s' % (
     
    11581332                if qoid and n == 'oid':
    11591333                    n = qoid
     1334                else:
     1335                    value = attnames[n].cast(value)
    11601336                row[n] = value
    11611337        return row
     
    12151391        qoid = _oid_key(table) if 'oid' in attnames else None
    12161392        params = []
    1217         param = partial(self._prepare_param,params=params)
     1393        param = partial(self._adapt_param,params=params)
    12181394        col = self.escape_identifier
    12191395        names, values, updates = [], [], []
     
    12591435                if qoid and n == 'oid':
    12601436                    n = qoid
     1437                else:
     1438                    value = attnames[n].cast(value)
    12611439                row[n] = value
    12621440        else:
     
    12791457            if n == 'oid':
    12801458                continue
    1281             if t in self._num_types:
     1459            t = t.simple
     1460            if t in _PgType._num_types:
    12821461                row[n] = 0
    12831462            elif t == 'bool':
     
    13281507                    raise KeyError('Missing primary key in row')
    13291508        params = []
    1330         param = partial(self._prepare_param, params=params)
     1509        param = partial(self._adapt_param, params=params)
    13311510        col = self.escape_identifier
    13321511        where = ' AND '.join('%s = %s' % (
Note: See TracChangeset for help on using the changeset viewer.