Changeset 732 for trunk/pg.py


Ignore:
Timestamp:
Jan 13, 2016, 7:45:34 AM (4 years ago)
Author:
cito
Message:

Better handling of quoted identifiers

Methods like get(), update() did not handle quoted identifiers properly
(i.e. identifiers with spaces, mixed case characters or special characters).
This has been improved and tests have been added to make sure this works.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/pg.py

    r730 r732  
    5050# Auxiliary functions that are independent from a DB connection:
    5151
    52 def _quote_class_name(cl):
    53     """Quote a class name.
    54 
    55     Class names are always quoted unless they contain a dot.
    56     In this ambiguous case quotes must be added manually.
    57 
    58     """
    59     if '.' not in cl:
    60         cl = '"%s"' % cl
    61     return cl
    62 
    63 
    64 def _quote_class_param(cl, param):
    65     """Quote parameter representing a class name.
    66 
    67     The parameter is automatically quoted unless the class name contains a dot.
    68     In this ambiguous case quotes must be added manually.
    69 
    70     """
    71     if isinstance(param, int):
    72         param = "$%d" % param
    73     if '.' not in cl:
    74         param = 'quote_ident(%s)' % (param,)
    75     return param
    76 
    77 
    7852def _oid_key(cl):
    79     """Build oid key from qualified class name."""
     53    """Build oid key from a class name."""
    8054    return 'oid(%s)' % cl
    8155
     
    329303                print(s)
    330304
     305    def _escape_qualified_name(self, s):
     306        """Escape a qualified name.
     307
     308        Escapes the name for use as an SQL identifier, unless the
     309        name contains a dot, in which case the name is ambiguous
     310        (could be a qualified name or just a name with a dot in it)
     311        and must be quoted manually by the caller.
     312
     313        """
     314        if '.' not in s:
     315            s = self.escape_identifier(s)
     316        return s
     317
    331318    @staticmethod
    332319    def _make_bool(d):
     
    362349
    363350    def _prepare_bytea(self, d):
     351        """Prepare a bytea parameter."""
    364352        return self.escape_bytea(d)
    365353
     
    384372        return '$%d' % len(params)
    385373
     374    @staticmethod
     375    def _prepare_qualified_param(cl, param):
     376        """Quote parameter representing a qualified name.
     377
     378        Escapes the name for use as an SQL parameter, unless the
     379        name contains a dot, in which case the name is ambiguous
     380        (could be a qualified name or just a name with a dot in it)
     381        and must be quoted manually by the caller.
     382
     383        """
     384        if isinstance(param, int):
     385            param = "$%d" % param
     386        if '.' not in cl:
     387            param = 'quote_ident(%s)' % (param,)
     388        return param
     389
    386390    # Public methods
    387391
     
    508512                " AND NOT a.attisdropped"
    509513                " WHERE i.indrelid=%s::regclass"
    510                 " AND i.indisprimary" % _quote_class_param(cl, 1))
     514                " AND i.indisprimary" % self._prepare_qualified_param(cl, 1))
    511515            pkey = self.db.query(q, (cl,)).getresult()
    512516            if not pkey:
     
    573577                " AND NOT a.attisdropped") % (
    574578                    '::regtype' if self._regtypes else '',
    575                     _quote_class_param(cl, 1))
     579                    self._prepare_qualified_param(cl, 1))
    576580            names = self.db.query(q, (cl,)).getresult()
    577581            if not names:
     
    602606        except KeyError:  # cache miss, ask the database
    603607            q = "SELECT has_table_privilege(%s, $2)" % (
    604                 _quote_class_param(cl, 1),)
     608                self._prepare_qualified_param(cl, 1),)
    605609            q = self.db.query(q, (cl, privilege))
    606610            ret = q.getresult()[0][0] == self._make_bool(True)
     
    637641        params = []
    638642        param = partial(self._prepare_param, params=params)
     643        col = self.escape_identifier
    639644        # We want the oid for later updates if that isn't the key
    640645        if keyname == 'oid':
     
    652657                if len(keyname) > 1:
    653658                    raise _prg_error('Composite key needs dict as arg')
    654                 arg = dict([(k, arg) for k in keyname])
    655             what = ', '.join(attnames)
     659                arg = dict((k, arg) for k in keyname)
     660            what = ', '.join(col(k) for k in attnames)
    656661            where = ' AND '.join(['%s = %s'
    657                 % (k, param(arg[k], attnames[k])) for k in keyname])
     662                % (col(k), param(arg[k], attnames[k])) for k in keyname])
    658663        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
    659             what, _quote_class_name(cl), where)
     664            what, self._escape_qualified_name(cl), where)
    660665        self._do_debug(q, params)
    661666        res = self.db.query(q, params).dictresult()
     
    694699        params = []
    695700        param = partial(self._prepare_param, params=params)
     701        col = self.escape_identifier
    696702        names, values = [], []
    697703        for n in attnames:
    698704            if n != 'oid' and n in d:
    699                 names.append('"%s"' % n)
     705                names.append(col(n))
    700706                values.append(param(d[n], attnames[n]))
    701707        names, values = ', '.join(names), ', '.join(values)
     
    706712            ret = ''
    707713        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (
    708             _quote_class_name(cl), names, values, ret)
     714            self._escape_qualified_name(cl), names, values, ret)
    709715        self._do_debug(q, params)
    710716        res = self.db.query(q, params)
     
    754760        params = []
    755761        param = partial(self._prepare_param, params=params)
     762        col = self.escape_identifier
    756763        if qoid in d:
    757764            where = 'oid = %s' % param(d[qoid], 'int')
     
    766773            try:
    767774                where = ' AND '.join(['%s = %s'
    768                     % (k, param(d[k], attnames[k])) for k in keyname])
     775                    % (col(k), param(d[k], attnames[k])) for k in keyname])
    769776            except KeyError:
    770777                raise _prg_error('Update needs primary key or oid.')
     
    772779        for n in attnames:
    773780            if n in d and n not in keyname:
    774                 values.append('%s = %s' % (n, param(d[n], attnames[n])))
     781                values.append('%s = %s' % (col(n), param(d[n], attnames[n])))
    775782        if not values:
    776783            return d
     
    782789            ret = ''
    783790        q = 'UPDATE %s SET %s WHERE %s%s' % (
    784             _quote_class_name(cl), values, where, ret)
     791            self._escape_qualified_name(cl), values, where, ret)
    785792        self._do_debug(q, params)
    786793        res = self.db.query(q, params)
     
    859866                keyname = (keyname,)
    860867            attnames = self.get_attnames(cl)
     868            col = self.escape_identifier
    861869            try:
    862870                where = ' AND '.join(['%s = %s'
    863                     % (k, param(d[k], attnames[k])) for k in keyname])
     871                    % (col(k), param(d[k], attnames[k])) for k in keyname])
    864872            except KeyError:
    865873                raise _prg_error('Delete needs primary key or oid.')
    866         q = 'DELETE FROM %s WHERE %s' % (_quote_class_name(cl), where)
     874        q = 'DELETE FROM %s WHERE %s' % (
     875            self._escape_qualified_name(cl), where)
    867876        self._do_debug(q, params)
    868877        return int(self.db.query(q, params))
Note: See TracChangeset for help on using the changeset viewer.