Changeset 799 for trunk/pg.py


Ignore:
Timestamp:
Jan 31, 2016, 1:45:58 PM (3 years ago)
Author:
cito
Message:

Improve adaptation and add query_formatted() method

Also added more tests and documentation.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/pg.py

    r798 r799  
    3535import warnings
    3636
     37from datetime import date, time, datetime, timedelta
    3738from decimal import Decimal
     39from math import isnan, isinf
    3840from collections import namedtuple
    39 from functools import partial
    4041from operator import itemgetter
    4142from re import compile as regex
     
    145146
    146147
    147 class _SimpleType(dict):
     148class _SimpleTypes(dict):
    148149    """Dictionary mapping pg_type names to simple type names."""
    149150
     
    169170        return 'text'
    170171
    171 _simpletype = _SimpleType()
    172 
    173 
    174 class _Adapt:
    175     """Mixin providing methods for adapting records and record elements.
    176 
    177     This is used when passing values from one of the higher level DB
    178     methods as parameters for a query.
    179 
    180     This class must be mixed in to a connection class, because it needs
    181     connection specific methods such as escape_bytea().
     172_simpletypes = _SimpleTypes()
     173
     174
     175def _quote_if_unqualified(param, name):
     176    """Quote parameter representing a qualified name.
     177
     178    Puts a quote_ident() call around the give parameter unless
     179    the name contains a dot, in which case the name is ambiguous
     180    (could be a qualified name or just a name with a dot in it)
     181    and must be quoted manually by the caller.
    182182    """
     183    if isinstance(name, basestring) and '.' not in name:
     184        return 'quote_ident(%s)' % (param,)
     185    return param
     186
     187
     188class _ParameterList(list):
     189    """Helper class for building typed parameter lists."""
     190
     191    def add(self, value, typ=None):
     192        """Typecast value with known database type and build parameter list.
     193
     194        If this is a literal value, it will be returned as is.  Otherwise, a
     195        placeholder will be returned and the parameter list will be augmented.
     196        """
     197        value = self.adapt(value, typ)
     198        if isinstance(value, Literal):
     199            return value
     200        self.append(value)
     201        return '$%d' % len(self)
     202
     203
     204class Literal(str):
     205    """Wrapper class for marking literal SQL values."""
     206
     207
     208class Json:
     209    """Wrapper class for marking Json values."""
     210
     211    def __init__(self, obj):
     212        self.obj = obj
     213
     214
     215class Bytea(bytes):
     216    """Wrapper class for marking Bytea values."""
     217
     218
     219class Adapter:
     220    """Class providing methods for adapting parameters to the database."""
    183221
    184222    _bool_true_values = frozenset('t true 1 y yes on'.split())
     
    190228    _re_record_quote = regex(r'[(,"\\]')
    191229    _re_array_escape = _re_record_escape = regex(r'(["\\])')
     230
     231    def __init__(self, db):
     232        self.db = db
     233        self.encode_json = db.encode_json
     234        db = db.db
     235        self.escape_bytea = db.escape_bytea
     236        self.escape_string = db.escape_string
    192237
    193238    @classmethod
     
    206251            return None
    207252        if isinstance(v, basestring) and v.lower() in cls._date_literals:
    208             return _Literal(v)
     253            return Literal(v)
    209254        return v
    210255
     
    298343    def _adapt_record(self, v, typ):
    299344        """Adapt a record parameter with given type."""
    300         typ = typ.attnames.values()
     345        typ = self.get_attnames(typ).values()
    301346        if len(typ) != len(v):
    302347            raise TypeError('Record parameter %s has wrong size' % v)
    303         return '(%s)' % ','.join(getattr(self,
    304             '_adapt_record_%s' % t.simple)(v) for v, t in zip(v, typ))
    305 
    306     @classmethod
    307     def _adapt_record_text(cls, v):
    308         """Adapt a text type record component."""
    309         if v is None:
    310             return ''
    311         if not v:
    312             return '""'
    313         v = str(v)
    314         if cls._re_record_quote.search(v):
    315             v = '"%s"' % cls._re_record_escape.sub(r'\\\1', v)
    316         return v
    317 
    318     _adapt_record_date = _adapt_record_text
    319 
    320     @classmethod
    321     def _adapt_record_bool(cls, v):
    322         """Adapt a boolean record component."""
    323         if v is None:
    324             return ''
    325         if isinstance(v, basestring):
    326             if not v:
    327                 return ''
    328             v = v.lower() in cls._bool_true_values
    329         return 't' if v else 'f'
    330 
    331     @staticmethod
    332     def _adapt_record_num(v):
    333         """Adapt a numeric record component."""
    334         if not v and v != 0:
    335             return ''
    336         return str(v)
    337 
    338     _adapt_record_int = _adapt_record_float = _adapt_record_money = \
    339         _adapt_record_num
    340 
    341     def _adapt_record_bytea(self, v):
    342         if v is None:
    343             return ''
    344         v = self.escape_bytea(v)
    345         if bytes is not str and isinstance(v, bytes):
    346             v = v.decode('ascii')
    347         return v.replace('\\', '\\\\')
    348 
    349     def _adapt_record_json(self, v):
    350         """Adapt a bytea record component."""
    351         if not v:
    352             return ''
    353         if not isinstance(v, basestring):
    354             v = self.encode_json(v)
    355         if self._re_array_quote.search(v):
    356             v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
    357         return v
    358 
    359     def _adapt_param(self, value, typ, params):
    360         """Adapt and add a parameter to the list."""
    361         if isinstance(value, _Literal):
    362             return value
    363         if value is not None:
    364             simple = typ.simple
     348        adapt = self.adapt
     349        value = []
     350        for v, t in zip(v, typ):
     351            v = adapt(v, t)
     352            if v is None:
     353                v = ''
     354            elif not v:
     355                v = '""'
     356            else:
     357                if isinstance(v, bytes):
     358                    if str is not bytes:
     359                        v = v.decode('ascii')
     360                else:
     361                    v = str(v)
     362                if self._re_record_quote.search(v):
     363                    v = '"%s"' % self._re_record_escape.sub(r'\\\1', v)
     364            value.append(v)
     365        return '(%s)' % ','.join(value)
     366
     367    def adapt(self, value, typ=None):
     368        """Adapt a value with known database type."""
     369        if value is not None and not isinstance(value, Literal):
     370            if typ:
     371                simple = self.get_simple_name(typ)
     372            else:
     373                typ = simple = self.guess_simple_type(value) or 'text'
     374            try:
     375                value = value.__pg_str__(typ)
     376            except AttributeError:
     377                pass
    365378            if simple == 'text':
    366379                pass
     
    375388                adapt = getattr(self, '_adapt_%s' % simple)
    376389                value = adapt(value)
    377                 if isinstance(value, _Literal):
    378                     return value
    379         params.append(value)
    380         return '$%d' % len(params)
     390        return value
     391
     392    @staticmethod
     393    def simple_type(name):
     394        """Create a simple database type with given attribute names."""
     395        typ = DbType(name)
     396        typ.simple = name
     397        return typ
     398
     399    @staticmethod
     400    def get_simple_name(typ):
     401        """Get the simple name of a database type."""
     402        if isinstance(typ, DbType):
     403            return typ.simple
     404        return _simpletypes[typ]
     405
     406    @staticmethod
     407    def get_attnames(typ):
     408        """Get the attribute names of a composite database type."""
     409        if isinstance(typ, DbType):
     410            return typ.attnames
     411        return {}
     412
     413    @classmethod
     414    def guess_simple_type(cls, value):
     415        """Try to guess which database type the given value has."""
     416        if isinstance(value, Bytea):
     417            return 'bytea'
     418        if isinstance(value, basestring):
     419            return 'text'
     420        if isinstance(value, bool):
     421            return 'bool'
     422        if isinstance(value, (int, long)):
     423            return 'int'
     424        if isinstance(value, float):
     425            return 'float'
     426        if isinstance(value, Decimal):
     427            return 'num'
     428        if isinstance(value, (date, time, datetime, timedelta)):
     429            return 'date'
     430        if isinstance(value, list):
     431            return '%s[]' % cls.guess_simple_base_type(value)
     432        if isinstance(value, tuple):
     433            simple_type = cls.simple_type
     434            typ = simple_type('record')
     435            guess = cls.guess_simple_type
     436            typ._get_attnames = lambda _self: AttrDict(
     437                (str(n + 1), simple_type(guess(v)))
     438                for n, v in enumerate(value))
     439            return typ
     440
     441    @classmethod
     442    def guess_simple_base_type(cls, value):
     443        """Try to guess the base type of a given array."""
     444        for v in value:
     445            if isinstance(v, list):
     446                typ = cls.guess_simple_base_type(v)
     447            else:
     448                typ = cls.guess_simple_type(v)
     449            if typ:
     450                return typ
     451
     452    def adapt_inline(self, value, nested=False):
     453        """Adapt a value that is put into the SQL and needs to be quoted."""
     454        if value is None:
     455            return 'NULL'
     456        if isinstance(value, Literal):
     457            return value
     458        if isinstance(value, Bytea):
     459            value = self.escape_bytea(value)
     460            if bytes is not str:  # Python >= 3.0
     461                value = value.decode('ascii')
     462        elif isinstance(value, Json):
     463            if value.encode:
     464                return value.encode()
     465            value = self.encode_json(value)
     466        elif isinstance(value, (datetime, date, time, timedelta)):
     467            value = str(value)
     468        if isinstance(value, basestring):
     469            value = self.escape_string(value)
     470            return "'%s'" % value
     471        if isinstance(value, bool):
     472            return 'true' if value else 'false'
     473        if isinstance(value, float):
     474            if isinf(value):
     475                return "'-Infinity'" if value < 0 else "'Infinity'"
     476            if isnan(value):
     477                return "'NaN'"
     478            return value
     479        if isinstance(value, (int, long, Decimal)):
     480            return value
     481        if isinstance(value, list):
     482            q = self.adapt_inline
     483            s = '[%s]' if nested else 'ARRAY[%s]'
     484            return s % ','.join(str(q(v, nested=True)) for v in value)
     485        if isinstance(value, tuple):
     486            q = self.adapt_inline
     487            return '(%s)' % ','.join(str(q(v)) for v in value)
     488        try:
     489            value = value.__pg_repr__()
     490        except AttributeError:
     491            raise InterfaceError(
     492                'Do not know how to adapt type %s' % type(value))
     493        if isinstance(value, (tuple, list)):
     494            value = self.adapt_inline(value)
     495        return value
     496
     497    def parameter_list(self):
     498        """Return a parameter list for parameters with known database types.
     499
     500        The list has an add(value, typ) method that will build up the
     501        list and return either the literal value or a placeholder.
     502        """
     503        params = _ParameterList()
     504        params.adapt = self.adapt
     505        return params
     506
     507    def format_query(self, command, values, types=None, inline=False):
     508        """Format a database query using the given values and types."""
     509        if inline and types:
     510            raise ValueError('Typed parameters must be sent separately')
     511        params = self.parameter_list()
     512        if isinstance(values, (list, tuple)):
     513            if inline:
     514                adapt = self.adapt_inline
     515                literals = [adapt(value) for value in values]
     516            else:
     517                add = params.add
     518                literals = []
     519                append = literals.append
     520                if types:
     521                    if (not isinstance(types, (list, tuple)) or
     522                            len(types) != len(values)):
     523                        raise TypeError('The values and types do not match')
     524                    for value, typ in zip(values, types):
     525                        append(add(value, typ))
     526                else:
     527                    for value in values:
     528                        append(add(value))
     529            command = command % tuple(literals)
     530        elif isinstance(values, dict):
     531            if inline:
     532                adapt = self.adapt_inline
     533                literals = dict((key, adapt(value))
     534                    for key, value in values.items())
     535            else:
     536                add = params.add
     537                literals = {}
     538                if types:
     539                    if (not isinstance(types, dict) or
     540                            len(types) < len(values)):
     541                        raise TypeError('The values and types do not match')
     542                    for key in sorted(values):
     543                        literals[key] = add(values[key], types[key])
     544                else:
     545                    for key in sorted(values):
     546                        literals[key] = add(values[key])
     547            command = command % literals
     548        else:
     549            raise TypeError('The values must be passed as tuple, list or dict')
     550        return command, params
    381551
    382552
     
    485655            for t in typ:
    486656                self[t] = cast
    487                 self.pop('_%s % t', None)
     657                self.pop('_%s' % t, None)
    488658
    489659    def reset(self, typ=None):
     
    492662        When no type is specified, all typecasts will be reset.
    493663        """
    494         defaults = self.defaults
    495664        if typ is None:
    496665            self.clear()
    497             self.update(defaults)
    498666        else:
    499667            if isinstance(typ, basestring):
    500668                typ = [typ]
    501669            for t in typ:
    502                 self.set(t, defaults.get(t))
     670                self.pop(t, None)
    503671
    504672    @classmethod
     
    522690            for t in typ:
    523691                defaults[t] = cast
    524                 defaults.pop('_%s % t', None)
     692                defaults.pop('_%s' % t, None)
    525693
    526694    def get_attnames(self, typ):
     
    603771        if oid in self:
    604772            return self[oid]
    605         simple = 'record' if relid else _simpletype[pgtype]
     773        simple = 'record' if relid else _simpletypes[pgtype]
    606774        typ = DbType(regtype if self._regtypes else simple)
    607775        typ.oid = oid
     
    622790                " typtype, typcategory, typdelim, typrelid"
    623791                " FROM pg_type WHERE oid=%s::regtype" %
    624                 (DB._adapt_qualified_param(key, 1),), (key,)).getresult()
     792                (_quote_if_unqualified('$1', key),), (key,)).getresult()
    625793        except ProgrammingError:
    626794            res = None
     
    675843            return value
    676844        return cast(value)
    677 
    678 
    679 class _Literal(str):
    680     """Wrapper class for literal SQL."""
    681845
    682846
     
    8601024# The actual PostGreSQL database connection interface:
    8611025
    862 class DB(_Adapt):
     1026class DB:
    8631027    """Wrapper class for the _pg connection type."""
    8641028
     
    8961060        self._privileges = {}
    8971061        self._args = args, kw
     1062        self.adapter = Adapter(self)
    8981063        self.dbtypes = DbTypes(self)
    8991064        db.set_cast_hook(self.dbtypes.typecast)
     
    9671132        """Create a human readable parameter list."""
    9681133        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
    969 
    970     @staticmethod
    971     def _adapt_qualified_param(name, param):
    972         """Quote parameter representing a qualified name.
    973 
    974         Escapes the name for use as an SQL parameter, unless the
    975         name contains a dot, in which case the name is ambiguous
    976         (could be a qualified name or just a name with a dot in it)
    977         and must be quoted manually by the caller.
    978 
    979         """
    980         if isinstance(param, int):
    981             param = "$%d" % param
    982         if isinstance(name, basestring) and '.' not in name:
    983             param = 'quote_ident(%s)' % (param,)
    984         return param
    9851134
    9861135    # Public methods
     
    12231372        return self.db.query(command)
    12241373
     1374    def query_formatted(self, command, parameters, types=None, inline=False):
     1375        """Execute a formatted SQL command string.
     1376
     1377        Similar to query, but using Python format placeholders of the form
     1378        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
     1379        The parameters must be passed as a tuple, list or dict.  You can
     1380        also pass a corresponding tuple, list or dict of database types in
     1381        order to format the parameters properly in case there is ambiguity.
     1382
     1383        If you set inline to True, the parameters will be sent to the database
     1384        embedded in the SQL command, otherwise they will be sent separately.
     1385        """
     1386        return self.query(*self.adapter.format_query(
     1387            command, parameters, types, inline))
     1388
    12251389    def pkey(self, table, composite=False, flush=False):
    12261390        """Get or set the primary key of a table.
     
    12481412                " WHERE i.indrelid=%s::regclass"
    12491413                " AND i.indisprimary ORDER BY a.attnum") % (
    1250                     self._adapt_qualified_param(table, 1),)
     1414                    _quote_if_unqualified('$1', table),)
    12511415            pkey = self.db.query(q, (table,)).getresult()
    12521416            if not pkey:
     
    13211485                " WHERE a.attrelid = %s::regclass AND %s"
    13221486                " AND NOT a.attisdropped ORDER BY a.attnum") % (
    1323                     self._adapt_qualified_param(table, 1), q)
     1487                    _quote_if_unqualified('$1', table), q)
    13241488            names = self.db.query(q, (table,)).getresult()
    13251489            types = self.dbtypes
     
    13481512        except KeyError:  # cache miss, ask the database
    13491513            q = "SELECT has_table_privilege(%s, $2)" % (
    1350                 self._adapt_qualified_param(table, 1),)
     1514                _quote_if_unqualified('$1', table),)
    13511515            q = self.db.query(q, (table, privilege))
    13521516            ret = q.getresult()[0][0] == self._make_bool(True)
     
    14051569                    'Differing number of items in keyname and row')
    14061570            row = dict(zip(keyname, row))
    1407         params = []
    1408         param = partial(self._adapt_param, params=params)
     1571        params = self.adapter.parameter_list()
     1572        adapt = params.add
    14091573        col = self.escape_identifier
    14101574        what = 'oid, *' if qoid else '*'
    14111575        where = ' AND '.join('%s = %s' % (
    1412             col(k), param(row[k], attnames[k])) for k in keyname)
     1576            col(k), adapt(row[k], attnames[k])) for k in keyname)
    14131577        if 'oid' in row:
    14141578            if qoid:
     
    14511615        attnames = self.get_attnames(table)
    14521616        qoid = _oid_key(table) if 'oid' in attnames else None
    1453         params = []
    1454         param = partial(self._adapt_param, params=params)
     1617        params = self.adapter.parameter_list()
     1618        adapt = params.add
    14551619        col = self.escape_identifier
    14561620        names, values = [], []
     
    14581622            if n in row:
    14591623                names.append(col(n))
    1460                 values.append(param(row[n], attnames[n]))
     1624                values.append(adapt(row[n], attnames[n]))
    14611625        if not names:
    14621626            raise _prg_error('No column found that can be inserted')
     
    15121676                else:
    15131677                    raise KeyError('Missing primary key in row')
    1514         params = []
    1515         param = partial(self._adapt_param, params=params)
     1678        params = self.adapter.parameter_list()
     1679        adapt = params.add
    15161680        col = self.escape_identifier
    15171681        where = ' AND '.join('%s = %s' % (
    1518             col(k), param(row[k], attnames[k])) for k in keyname)
     1682            col(k), adapt(row[k], attnames[k])) for k in keyname)
    15191683        if 'oid' in row:
    15201684            if qoid:
     
    15251689        for n in attnames:
    15261690            if n in row and n not in keyname:
    1527                 values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
     1691                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
    15281692        if not values:
    15291693            return row
     
    15951759        attnames = self.get_attnames(table)
    15961760        qoid = _oid_key(table) if 'oid' in attnames else None
    1597         params = []
    1598         param = partial(self._adapt_param,params=params)
     1761        params = self.adapter.parameter_list()
     1762        adapt = params.add
    15991763        col = self.escape_identifier
    16001764        names, values, updates = [], [], []
     
    16021766            if n in row:
    16031767                names.append(col(n))
    1604                 values.append(param(row[n], attnames[n]))
     1768                values.append(adapt(row[n], attnames[n]))
    16051769        names, values = ', '.join(names), ', '.join(values)
    16061770        try:
     
    17091873                else:
    17101874                    raise KeyError('Missing primary key in row')
    1711         params = []
    1712         param = partial(self._adapt_param, params=params)
     1875        params = self.adapter.parameter_list()
     1876        adapt = params.add
    17131877        col = self.escape_identifier
    17141878        where = ' AND '.join('%s = %s' % (
    1715             col(k), param(row[k], attnames[k])) for k in keyname)
     1879            col(k), adapt(row[k], attnames[k])) for k in keyname)
    17161880        if 'oid' in row:
    17171881            if qoid:
Note: See TracChangeset for help on using the changeset viewer.