Changeset 781 for trunk/pg.py


Ignore:
Timestamp:
Jan 25, 2016, 3:44:52 PM (4 years ago)
Author:
cito
Message:

Add full support for PostgreSQL array types

At the core of this patch is a fast parser for the peculiar syntax of
literal array expressions in PostgreSQL that was added to the C module.
This is not trivial, because PostgreSQL arrays can be multidimensional
and the syntax is different from Python and SQL expressions.

The Python pg and pgdb modules make use of this parser so that they can
return database columns containing PostgreSQL arrays to Python as lists.
Also added quoting methods that allow passing PostgreSQL arrays as lists
to insert()/update() and execute/executemany(). These methods are simpler
and were implemented in Python but needed support from the regex module.

The patch also adds makes getresult() in pg automatically return bytea
values in unescaped form as bytes strings. Before, it was necessary to
call unescape_bytea manually. The pgdb module did this already.

The patch includes some more refactorings and simplifications regarding
the quoting and casting in pg and pgdb.

Some references to antique PostgreSQL types that are not used any more
in the supported PostgreSQL versions have been removed.

Also added documentation and tests for the new features.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/pg.py

    r774 r781  
    3939from functools import partial
    4040from operator import itemgetter
     41from re import compile as regex
    4142from json import loads as jsondecode, dumps as jsonencode
    4243
     
    132133
    133134
    134 # Auxiliary functions that are independent from a DB connection:
     135# Auxiliary classes and functions that are independent from a DB connection:
    135136
    136137def _oid_key(table):
     
    139140
    140141
    141 def _simpletype(typ):
    142     """Determine a simplified name a pg_type name."""
    143     if typ.startswith('bool'):
    144         return 'bool'
    145     if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
    146         return 'date'
    147     if typ.startswith(('cid', 'oid', 'int', 'xid')):
    148         return 'int'
    149     if typ.startswith('float'):
    150         return 'float'
    151     if typ.startswith('numeric'):
    152         return 'num'
    153     if typ.startswith('money'):
    154         return 'money'
    155     if typ.startswith('bytea'):
    156         return 'bytea'
    157     if typ.startswith('json'):
    158         return 'json'
    159     return 'text'
     142class _SimpleType(dict):
     143    """Dictionary mapping pg_type names to simple type names."""
     144
     145    _types = {'bool': 'bool',
     146        'bytea': 'bytea',
     147        'date': 'date interval time timetz timestamp timestamptz'
     148            ' abstime reltime',  # these are very old
     149        'float': 'float4 float8',
     150        'int': 'cid int2 int4 int8 oid xid',
     151        'json': 'json jsonb',
     152        'num': 'numeric',
     153        'money': 'money',
     154        'text': 'bpchar char name text varchar'}
     155
     156    def __init__(self):
     157        for typ, keys in self._types.items():
     158            for key in keys.split():
     159                self[key] = typ
     160                self['_%s' % key] = '%s[]' % typ
     161
     162    @staticmethod
     163    def __missing__(key):
     164        return 'text'
     165
     166_simpletype = _SimpleType()
     167
     168
     169class _Literal(str):
     170    """Wrapper class for literal SQL."""
    160171
    161172
     
    414425        """Print a debug message"""
    415426        if self.debug:
    416             s = '\n'.join(args)
     427            s = '\n'.join(str(arg) for arg in args)
    417428            if isinstance(self.debug, basestring):
    418429                print(self.debug % s)
     
    459470            return None
    460471        if isinstance(d, basestring) and d.lower() in self._date_literals:
    461             raise ValueError
     472            return _Literal(d)
    462473        return d
    463474
     
    465476        ' int2 int4 int8 float4 float8 numeric money'.split())
    466477
    467     def _prepare_num(self, d):
     478    @staticmethod
     479    def _prepare_num(d):
    468480        """Prepare a numeric parameter."""
    469481        if not d and d != 0:
     
    471483        return d
    472484
     485    _prepare_int = _prepare_float = _prepare_money = _prepare_num
     486
    473487    def _prepare_bytea(self, d):
    474488        """Prepare a bytea parameter."""
     
    477491    def _prepare_json(self, d):
    478492        """Prepare a json parameter."""
     493        if not d:
     494            return None
     495        if isinstance(d, basestring):
     496            return d
    479497        return self.encode_json(d)
    480498
    481     _prepare_funcs = dict(  # quote methods for each type
    482         bool=_prepare_bool, date=_prepare_date,
    483         int=_prepare_num, num=_prepare_num, float=_prepare_num,
    484         money=_prepare_num, bytea=_prepare_bytea, json=_prepare_json)
     499    _re_array_escape = regex(r'(["\\])')
     500    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
     501
     502    def _prepare_bool_array(self, d):
     503        """Prepare a bool array parameter."""
     504        if isinstance(d, list):
     505            return '{%s}' % ','.join(self._prepare_bool_array(v) for v in d)
     506        if d is None:
     507            return 'null'
     508        if isinstance(d, basestring):
     509            if not d:
     510                return 'null'
     511            d = d.lower() in self._bool_true_values
     512        return 't' if d else 'f'
     513
     514    def _prepare_num_array(self, d):
     515        """Prepare a numeric array parameter."""
     516        if isinstance(d, list):
     517            return '{%s}' % ','.join(self._prepare_num_array(v) for v in d)
     518        if not d and d != 0:
     519            return 'null'
     520        return str(d)
     521
     522    _prepare_int_array = _prepare_float_array = _prepare_money_array = \
     523            _prepare_num_array
     524
     525    def _prepare_text_array(self, d):
     526        """Prepare a text array parameter."""
     527        if isinstance(d, list):
     528            return '{%s}' % ','.join(self._prepare_text_array(v) for v in d)
     529        if d is None:
     530            return 'null'
     531        if not d:
     532            return '""'
     533        d = str(d)
     534        if self._re_array_quote.search(d):
     535            d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
     536        return d
     537
     538    def _prepare_bytea_array(self, d):
     539        """Prepare a bytea array parameter."""
     540        if isinstance(d, list):
     541            return '{%s}' % ','.join(self._prepare_bytea_array(v) for v in d)
     542        if d is None:
     543            return 'null'
     544        return self.escape_bytea(d).replace('\\', '\\\\')
     545
     546    def _prepare_json_array(self, d):
     547        """Prepare a json array parameter."""
     548        if isinstance(d, list):
     549            return '{%s}' % ','.join(self._prepare_json_array(v) for v in d)
     550        if not d:
     551            return 'null'
     552        if not isinstance(d, basestring):
     553            d = self.encode_json(d)
     554        if self._re_array_quote.search(d):
     555            d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
     556        return d
    485557
    486558    def _prepare_param(self, value, typ, params):
    487559        """Prepare and add a parameter to the list."""
     560        if isinstance(value, _Literal):
     561            return value
    488562        if value is not None and typ != 'text':
    489             prepare = self._prepare_funcs[typ]
    490             try:
    491                 value = prepare(self, value)
    492             except ValueError:
     563            if typ.endswith('[]'):
     564                if isinstance(value, list):
     565                    prepare = getattr(self, '_prepare_%s_array' % typ[:-2])
     566                    value = prepare(value)
     567                elif isinstance(value, basestring):
     568                    value = value.strip()
     569                    if not value.startswith('{') or not value.endswith('}'):
     570                        if value[:5].lower() == 'array':
     571                            value = value[5:].lstrip()
     572                        if value.startswith('[') and value.endswith(']'):
     573                            value = _Literal('ARRAY%s' % value)
     574                        else:
     575                            raise ValueError(
     576                                'Invalid array expression: %s' % value)
     577                else:
     578                    raise ValueError('Invalid array parameter: %s' % value)
     579            else:
     580                prepare = getattr(self, '_prepare_%s' % typ)
     581                value = prepare(value)
     582            if isinstance(value, _Literal):
    493583                return value
    494584        params.append(value)
     
    726816            self.db.query(q)
    727817
    728     def query(self, qstr, *args):
     818    def query(self, command, *args):
    729819        """Execute a SQL command string.
    730820
     
    748838        if not self.db:
    749839            raise _int_error('Connection is not valid')
    750         self._do_debug(qstr)
    751         return self.db.query(qstr, args)
     840        if args:
     841            self._do_debug(command, args)
     842            return self.db.query(command, args)
     843        self._do_debug(command)
     844        return self.db.query(command)
    752845
    753846    def pkey(self, table, composite=False, flush=False):
     
    850943            names = self.db.query(q, (table,)).getresult()
    851944            if not self._regtypes:
    852                 names = ((name, _simpletype(typ)) for name, typ in names)
     945                names = ((name, _simpletype[typ]) for name, typ in names)
    853946            names = AttrDict(names)
    854947            attnames[table] = names  # cache it
     
    9511044            if qoid and n == 'oid':
    9521045                n = qoid
    953             elif value is not None and attnames.get(n) == 'bytea':
    954                 value = self.unescape_bytea(value)
    9551046            row[n] = value
    9561047        return row
     
    9861077                names.append(col(n))
    9871078                values.append(param(row[n], attnames[n]))
     1079        if not names:
     1080            raise _prg_error('No column found that can be inserted')
    9881081        names, values = ', '.join(names), ', '.join(values)
    9891082        ret = 'oid, *' if qoid else '*'
     
    9971090                if qoid and n == 'oid':
    9981091                    n = qoid
    999                 elif value is not None and attnames.get(n) == 'bytea':
    1000                     value = self.unescape_bytea(value)
    10011092                row[n] = value
    10021093        return row
     
    10661157                if qoid and n == 'oid':
    10671158                    n = qoid
    1068                 elif value is not None and attnames.get(n) == 'bytea':
    1069                     value = self.unescape_bytea(value)
    10701159                row[n] = value
    10711160        return row
     
    11691258                if qoid and n == 'oid':
    11701259                    n = qoid
    1171                 elif value is not None and attnames.get(n) == 'bytea':
    1172                     value = self.unescape_bytea(value)
    11731260                row[n] = value
    11741261        else:
Note: See TracChangeset for help on using the changeset viewer.