Changeset 793 for trunk


Ignore:
Timestamp:
Jan 28, 2016, 3:07:41 PM (4 years ago)
Author:
cito
Message:

Improve quoting and typecasting in the pg module

Larger refactoring of the code for adapting and typecasting in the pg module.
Things are now a lot cleaner and clearer.

The _Adapt class is responsible for all adapting of Python objects to their
PostgreSQL equivalents when sending data to the database. The typecasting
from PostgreSQL on output happens in the C module, except for the typecasting
of records which is new and provided by the _CastRecord class.

The classic module also did not work properly when regular type names were
switched on with use_regtypes(True), since the adapting of types relied on
the PyGreSQL type names. This has been solved by adding a new _PgType class
that is essentially the old type name, but augmented with all the necessary
information necessary to adapt types, particularly record types.

All tests in test_classic_dbwrapper now run twice, using opposite settings
for the various configuration settings like use_bool() or use_regtypes(),
in order to make sure that no internal functions rely on default settings.

Location:
trunk
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/pg.py

    r785 r793  
    165165
    166166_simpletype = _SimpleType()
     167
     168
     169class _Adapt:
     170    """Mixin providing methods for adapting records and record elements.
     171
     172    This is used when passing values from one of the higher level DB
     173    methods as parameters for a query.
     174
     175    This class must be mixed in to a connection class, because it needs
     176    connection specific methods such as escape_bytea().
     177    """
     178
     179    _bool_true_values = frozenset('t true 1 y yes on'.split())
     180
     181    _date_literals = frozenset('current_date current_time'
     182        ' current_timestamp localtime localtimestamp'.split())
     183
     184    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
     185    _re_record_quote = regex(r'[(,"\\]')
     186    _re_array_escape = _re_record_escape = regex(r'(["\\])')
     187
     188    @classmethod
     189    def _adapt_bool(cls, v):
     190        """Adapt a boolean parameter."""
     191        if isinstance(v, basestring):
     192            if not v:
     193                return None
     194            v = v.lower() in cls._bool_true_values
     195        return 't' if v else 'f'
     196
     197    @classmethod
     198    def _adapt_date(cls, v):
     199        """Adapt a date parameter."""
     200        if not v:
     201            return None
     202        if isinstance(v, basestring) and v.lower() in cls._date_literals:
     203            return _Literal(v)
     204        return v
     205
     206    @staticmethod
     207    def _adapt_num(v):
     208        """Adapt a numeric parameter."""
     209        if not v and v != 0:
     210            return None
     211        return v
     212
     213    _adapt_int = _adapt_float = _adapt_money = _adapt_num
     214
     215    def _adapt_bytea(self, v):
     216        """Adapt a bytea parameter."""
     217        return self.escape_bytea(v)
     218
     219    def _adapt_json(self, v):
     220        """Adapt a json parameter."""
     221        if not v:
     222            return None
     223        if isinstance(v, basestring):
     224            return v
     225        return self.encode_json(v)
     226
     227    @classmethod
     228    def _adapt_text_array(cls, v):
     229        """Adapt a text type array parameter."""
     230        if isinstance(v, list):
     231            adapt = cls._adapt_text_array
     232            return '{%s}' % ','.join(adapt(v) for v in v)
     233        if v is None:
     234            return 'null'
     235        if not v:
     236            return '""'
     237        v = str(v)
     238        if cls._re_array_quote.search(v):
     239            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
     240        return v
     241
     242    _adapt_date_array = _adapt_text_array
     243
     244    @classmethod
     245    def _adapt_bool_array(cls, v):
     246        """Adapt a boolean array parameter."""
     247        if isinstance(v, list):
     248            adapt = cls._adapt_bool_array
     249            return '{%s}' % ','.join(adapt(v) for v in v)
     250        if v is None:
     251            return 'null'
     252        if isinstance(v, basestring):
     253            if not v:
     254                return 'null'
     255            v = v.lower() in cls._bool_true_values
     256        return 't' if v else 'f'
     257
     258    @classmethod
     259    def _adapt_num_array(cls, v):
     260        """Adapt a numeric array parameter."""
     261        if isinstance(v, list):
     262            adapt = cls._adapt_num_array
     263            return '{%s}' % ','.join(adapt(v) for v in v)
     264        if not v and v != 0:
     265            return 'null'
     266        return str(v)
     267
     268    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
     269            _adapt_num_array
     270
     271    def _adapt_bytea_array(self, v):
     272        """Adapt a bytea array parameter."""
     273        if isinstance(v, list):
     274            return b'{' + b','.join(
     275                self._adapt_bytea_array(v) for v in v) + b'}'
     276        if v is None:
     277            return b'null'
     278        return self.escape_bytea(v).replace(b'\\', b'\\\\')
     279
     280    def _adapt_json_array(self, v):
     281        """Adapt a json array parameter."""
     282        if isinstance(v, list):
     283            adapt = self._adapt_json_array
     284            return '{%s}' % ','.join(adapt(v) for v in v)
     285        if not v:
     286            return 'null'
     287        if not isinstance(v, basestring):
     288            v = self.encode_json(v)
     289        if self._re_array_quote.search(v):
     290            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
     291        return v
     292
     293    def _adapt_record(self, v, typ):
     294        """Adapt a record parameter with given type."""
     295        typ = typ.attnames.values()
     296        if len(typ) != len(v):
     297            raise TypeError('Record parameter %s has wrong size' % v)
     298        return '(%s)' % ','.join(getattr(self,
     299            '_adapt_record_%s' % t.simple)(v) for v, t in zip(v, typ))
     300
     301    @classmethod
     302    def _adapt_record_text(cls, v):
     303        """Adapt a text type record component."""
     304        if v is None:
     305            return ''
     306        if not v:
     307            return '""'
     308        v = str(v)
     309        if cls._re_record_quote.search(v):
     310            v = '"%s"' % cls._re_record_escape.sub(r'\\\1', v)
     311        return v
     312
     313    _adapt_record_date = _adapt_record_text
     314
     315    @classmethod
     316    def _adapt_record_bool(cls, v):
     317        """Adapt a boolean record component."""
     318        if v is None:
     319            return ''
     320        if isinstance(v, basestring):
     321            if not v:
     322                return ''
     323            v = v.lower() in cls._bool_true_values
     324        return 't' if v else 'f'
     325
     326    @staticmethod
     327    def _adapt_record_num(v):
     328        """Adapt a numeric record component."""
     329        if not v and v != 0:
     330            return ''
     331        return str(v)
     332
     333    _adapt_record_int = _adapt_record_float = _adapt_record_money = \
     334        _adapt_record_num
     335
     336    def _adapt_record_bytea(self, v):
     337        if v is None:
     338            return ''
     339        v = self.escape_bytea(v)
     340        if bytes is not str and isinstance(v, bytes):
     341            v = v.decode('ascii')
     342        return v.replace('\\', '\\\\')
     343
     344    def _adapt_record_json(self, v):
     345        """Adapt a bytea record component."""
     346        if not v:
     347            return ''
     348        if not isinstance(v, basestring):
     349            v = self.encode_json(v)
     350        if self._re_array_quote.search(v):
     351            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
     352        return v
     353
     354    def _adapt_param(self, value, typ, params):
     355        """Adapt and add a parameter to the list."""
     356        if isinstance(value, _Literal):
     357            return value
     358        if value is not None:
     359            simple = typ.simple
     360            if simple == 'text':
     361                pass
     362            elif simple == 'record':
     363                if isinstance(value, tuple):
     364                    value = self._adapt_record(value, typ)
     365            elif simple.endswith('[]'):
     366                if isinstance(value, list):
     367                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
     368                    value = adapt(value)
     369            else:
     370                adapt = getattr(self, '_adapt_%s' % simple)
     371                value = adapt(value)
     372                if isinstance(value, _Literal):
     373                    return value
     374        params.append(value)
     375        return '$%d' % len(params)
     376
     377
     378class _CastRecord:
     379    """Class providing methods for casting records and record elements.
     380
     381    This is needed when getting result values from one of the higher level DB
     382    methods, since the lower level query method only casts the other types.
     383    """
     384
     385    @staticmethod
     386    def cast_bool(v):
     387        if not get_bool():
     388            return v
     389        return v[0] == 't'
     390
     391    @staticmethod
     392    def cast_bytea(v):
     393        return unescape_bytea(v)
     394
     395    @staticmethod
     396    def cast_float(v):
     397        return float(v)
     398
     399    @staticmethod
     400    def cast_int(v):
     401        return int(v)
     402
     403    @staticmethod
     404    def cast_json(v):
     405        cast = get_jsondecode()
     406        if not cast:
     407            return v
     408        return cast(v)
     409
     410    @staticmethod
     411    def cast_num(v):
     412        return (get_decimal() or float)(v)
     413
     414    @staticmethod
     415    def cast_money(v):
     416        point = get_decimal_point()
     417        if not point:
     418            return v
     419        if point != '.':
     420            v = v.replace(point, '.')
     421        v = v.replace('(', '-')
     422        v = ''.join(c for c in v if c.isdigit() or c in '.-')
     423        return (get_decimal() or float)(v)
     424
     425    @classmethod
     426    def cast(cls, v, typ):
     427        types = typ.attnames.values()
     428        cast = [getattr(cls, 'cast_%s' % t.simple, None) for t in types]
     429        v = cast_record(v, cast)
     430        return typ.namedtuple(*v)
     431
     432
     433class _PgType(str):
     434    """Class augmenting the simple type name with additional info."""
     435
     436    _num_types = frozenset('int float num money'
     437        ' int2 int4 int8 float4 float8 numeric money'.split())
     438
     439    @classmethod
     440    def create(cls, db, pgtype, regtype, typrelid):
     441        """Create a PostgreSQL type name with additional info."""
     442        simple = 'record' if typrelid else _simpletype[pgtype]
     443        self = cls(regtype if db._regtypes else simple)
     444        self.db = db
     445        self.simple = simple
     446        self.pgtype = pgtype
     447        self.regtype = regtype
     448        self.typrelid = typrelid
     449        self._attnames = self._namedtuple = None
     450        return self
     451
     452    @property
     453    def attnames(self):
     454        """Get names and types of the fields of a composite type."""
     455        if not self.typrelid:
     456            return None
     457        if not self._attnames:
     458            self._attnames = self.db.get_attnames(self.typrelid)
     459        return self._attnames
     460
     461    @property
     462    def namedtuple(self):
     463        """Return named tuple class representing a composite type."""
     464        if not self._namedtuple:
     465            self._namedtuple = namedtuple(self, self.attnames)
     466        return self._namedtuple
     467
     468    def cast(self, value):
     469        if value is not None and self.typrelid:
     470            value = _CastRecord.cast(value, self)
     471        return value
    167472
    168473
     
    350655# The actual PostGreSQL database connection interface:
    351656
    352 class DB(object):
     657class DB(_Adapt):
    353658    """Wrapper class for the _pg connection type."""
    354659
     
    452757        return bool(d) if get_bool() else ('t' if d else 'f')
    453758
    454     _bool_true_values = frozenset('t true 1 y yes on'.split())
    455 
    456     def _prepare_bool(self, d):
    457         """Prepare a boolean parameter."""
    458         if isinstance(d, basestring):
    459             if not d:
    460                 return None
    461             d = d.lower() in self._bool_true_values
    462         return 't' if d else 'f'
    463 
    464     _date_literals = frozenset('current_date current_time'
    465         ' current_timestamp localtime localtimestamp'.split())
    466 
    467     def _prepare_date(self, d):
    468         """Prepare a date parameter."""
    469         if not d:
    470             return None
    471         if isinstance(d, basestring) and d.lower() in self._date_literals:
    472             return _Literal(d)
    473         return d
    474 
    475     _num_types = frozenset('int float num money'
    476         ' int2 int4 int8 float4 float8 numeric money'.split())
    477 
    478     @staticmethod
    479     def _prepare_num(d):
    480         """Prepare a numeric parameter."""
    481         if not d and d != 0:
    482             return None
    483         return d
    484 
    485     _prepare_int = _prepare_float = _prepare_money = _prepare_num
    486 
    487     def _prepare_bytea(self, d):
    488         """Prepare a bytea parameter."""
    489         return self.escape_bytea(d)
    490 
    491     def _prepare_json(self, d):
    492         """Prepare a json parameter."""
    493         if not d:
    494             return None
    495         if isinstance(d, basestring):
    496             return d
    497         return self.encode_json(d)
    498 
    499     _re_array_escape = regex(r'(["\\])')
    500     _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    501 
    502     def _prepare_bool_array(self, d):
    503         """Prepare a bool array parameter."""
    504         if isinstance(d, list):
    505             return '{%s}' % ','.join(self._prepare_bool_array(v) for v in d)
    506         if d is None:
    507             return 'null'
    508         if isinstance(d, basestring):
    509             if not d:
    510                 return 'null'
    511             d = d.lower() in self._bool_true_values
    512         return 't' if d else 'f'
    513 
    514     def _prepare_num_array(self, d):
    515         """Prepare a numeric array parameter."""
    516         if isinstance(d, list):
    517             return '{%s}' % ','.join(self._prepare_num_array(v) for v in d)
    518         if not d and d != 0:
    519             return 'null'
    520         return str(d)
    521 
    522     _prepare_int_array = _prepare_float_array = _prepare_money_array = \
    523             _prepare_num_array
    524 
    525     def _prepare_text_array(self, d):
    526         """Prepare a text array parameter."""
    527         if isinstance(d, list):
    528             return '{%s}' % ','.join(self._prepare_text_array(v) for v in d)
    529         if d is None:
    530             return 'null'
    531         if not d:
    532             return '""'
    533         d = str(d)
    534         if self._re_array_quote.search(d):
    535             d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
    536         return d
    537 
    538     def _prepare_bytea_array(self, d):
    539         """Prepare a bytea array parameter."""
    540         if isinstance(d, list):
    541             return b'{' + b','.join(
    542                 self._prepare_bytea_array(v) for v in d) + b'}'
    543         if d is None:
    544             return b'null'
    545         return self.escape_bytea(d).replace(b'\\', b'\\\\')
    546 
    547     def _prepare_json_array(self, d):
    548         """Prepare a json array parameter."""
    549         if isinstance(d, list):
    550             return '{%s}' % ','.join(self._prepare_json_array(v) for v in d)
    551         if not d:
    552             return 'null'
    553         if not isinstance(d, basestring):
    554             d = self.encode_json(d)
    555         if self._re_array_quote.search(d):
    556             d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
    557         return d
    558 
    559     def _prepare_param(self, value, typ, params):
    560         """Prepare and add a parameter to the list."""
    561         if isinstance(value, _Literal):
    562             return value
    563         if value is not None and typ != 'text':
    564             if typ.endswith('[]'):
    565                 if isinstance(value, list):
    566                     prepare = getattr(self, '_prepare_%s_array' % typ[:-2])
    567                     value = prepare(value)
    568                 elif isinstance(value, basestring):
    569                     value = value.strip()
    570                     if not value.startswith('{') or not value.endswith('}'):
    571                         if value[:5].lower() == 'array':
    572                             value = value[5:].lstrip()
    573                         if value.startswith('[') and value.endswith(']'):
    574                             value = _Literal('ARRAY%s' % value)
    575                         else:
    576                             raise ValueError(
    577                                 'Invalid array expression: %s' % value)
    578                 else:
    579                     raise ValueError('Invalid array parameter: %s' % value)
    580             else:
    581                 prepare = getattr(self, '_prepare_%s' % typ)
    582                 value = prepare(value)
    583             if isinstance(value, _Literal):
    584                 return value
    585         params.append(value)
    586         return '$%d' % len(params)
    587 
    588759    def _list_params(self, params):
    589760        """Create a human readable parameter list."""
     
    591762
    592763    @staticmethod
    593     def _prepare_qualified_param(name, param):
     764    def _adapt_qualified_param(name, param):
    594765        """Quote parameter representing a qualified name.
    595766
     
    602773        if isinstance(param, int):
    603774            param = "$%d" % param
    604         if '.' not in name:
     775        if isinstance(name, basestring) and '.' not in name:
    605776            param = 'quote_ident(%s)' % (param,)
    606777        return param
     
    8701041                " WHERE i.indrelid=%s::regclass"
    8711042                " AND i.indisprimary ORDER BY a.attnum") % (
    872                     self._prepare_qualified_param(table, 1),)
     1043                    self._adapt_qualified_param(table, 1),)
    8731044            pkey = self.db.query(q, (table,)).getresult()
    8741045            if not pkey:
     
    9341105            names = attnames[table]
    9351106        except KeyError:  # cache miss, check the database
    936             q = ("SELECT a.attname, t.typname%s"
     1107            q = ("SELECT a.attname, t.typname, t.typname::regtype, t.typrelid"
    9371108                " FROM pg_attribute a"
    9381109                " JOIN pg_type t ON t.oid = a.atttypid"
     
    9401111                " AND (a.attnum > 0 OR a.attname = 'oid')"
    9411112                " AND NOT a.attisdropped ORDER BY a.attnum") % (
    942                     '::regtype' if self._regtypes else '',
    943                     self._prepare_qualified_param(table, 1))
     1113                    self._adapt_qualified_param(table, 1))
    9441114            names = self.db.query(q, (table,)).getresult()
    945             if not self._regtypes:
    946                 names = ((name, _simpletype[typ]) for name, typ in names)
     1115            names = ((name, _PgType.create(self, pgtype, regtype, typrelid))
     1116                for name, pgtype, regtype, typrelid in names)
    9471117            names = AttrDict(names)
    9481118            attnames[table] = names  # cache it
     
    9671137        except KeyError:  # cache miss, ask the database
    9681138            q = "SELECT has_table_privilege(%s, $2)" % (
    969                 self._prepare_qualified_param(table, 1),)
     1139                self._adapt_qualified_param(table, 1),)
    9701140            q = self.db.query(q, (table, privilege))
    9711141            ret = q.getresult()[0][0] == self._make_bool(True)
     
    10251195            row = dict(zip(keyname, row))
    10261196        params = []
    1027         param = partial(self._prepare_param, params=params)
     1197        param = partial(self._adapt_param, params=params)
    10281198        col = self.escape_identifier
    10291199        what = 'oid, *' if qoid else '*'
     
    10451215            if qoid and n == 'oid':
    10461216                n = qoid
     1217            else:
     1218                value = attnames[n].cast(value)
    10471219            row[n] = value
    10481220        return row
     
    10711243        qoid = _oid_key(table) if 'oid' in attnames else None
    10721244        params = []
    1073         param = partial(self._prepare_param, params=params)
     1245        param = partial(self._adapt_param, params=params)
    10741246        col = self.escape_identifier
    10751247        names, values = [], []
     
    10911263                if qoid and n == 'oid':
    10921264                    n = qoid
     1265                else:
     1266                    value = attnames[n].cast(value)
    10931267                row[n] = value
    10941268        return row
     
    11321306                    raise KeyError('Missing primary key in row')
    11331307        params = []
    1134         param = partial(self._prepare_param, params=params)
     1308        param = partial(self._adapt_param, params=params)
    11351309        col = self.escape_identifier
    11361310        where = ' AND '.join('%s = %s' % (
     
    11581332                if qoid and n == 'oid':
    11591333                    n = qoid
     1334                else:
     1335                    value = attnames[n].cast(value)
    11601336                row[n] = value
    11611337        return row
     
    12151391        qoid = _oid_key(table) if 'oid' in attnames else None
    12161392        params = []
    1217         param = partial(self._prepare_param,params=params)
     1393        param = partial(self._adapt_param,params=params)
    12181394        col = self.escape_identifier
    12191395        names, values, updates = [], [], []
     
    12591435                if qoid and n == 'oid':
    12601436                    n = qoid
     1437                else:
     1438                    value = attnames[n].cast(value)
    12611439                row[n] = value
    12621440        else:
     
    12791457            if n == 'oid':
    12801458                continue
    1281             if t in self._num_types:
     1459            t = t.simple
     1460            if t in _PgType._num_types:
    12821461                row[n] = 0
    12831462            elif t == 'bool':
     
    13281507                    raise KeyError('Missing primary key in row')
    13291508        params = []
    1330         param = partial(self._prepare_param, params=params)
     1509        param = partial(self._adapt_param, params=params)
    13311510        col = self.escape_identifier
    13321511        where = ' AND '.join('%s = %s' % (
  • trunk/tests/test_classic_dbwrapper.py

    r781 r793  
    380380    cls_set_up = False
    381381
     382    regtypes = None
     383
    382384    @classmethod
    383385    def setUpClass(cls):
     
    402404        self.assertTrue(self.cls_set_up)
    403405        self.db = DB()
     406        if self.regtypes is None:
     407            self.regtypes = self.db.use_regtypes()
     408        else:
     409            self.db.use_regtypes(self.regtypes)
    404410        query = self.db.query
    405411        query('set client_encoding=utf8')
     
    986992        r = get_attnames('test')
    987993        self.assertIsInstance(r, dict)
    988         self.assertEqual(r, dict(
    989             i2='int', i4='int', i8='int', d='num',
    990             f4='float', f8='float', m='money',
    991             v4='text', c4='text', t='text'))
     994        if self.regtypes:
     995            self.assertEqual(r, dict(
     996                i2='smallint', i4='integer', i8='bigint', d='numeric',
     997                f4='real', f8='double precision', m='money',
     998                v4='character varying', c4='character', t='text'))
     999        else:
     1000            self.assertEqual(r, dict(
     1001                i2='int', i4='int', i8='int', d='num',
     1002                f4='float', f8='float', m='money',
     1003                v4='text', c4='text', t='text'))
    9921004        self.createTable('test_table',
    9931005            'n int, alpha smallint, beta bool,'
     
    9951007        r = get_attnames('test_table')
    9961008        self.assertIsInstance(r, dict)
    997         self.assertEqual(r, dict(
    998             n='int', alpha='int', beta='bool',
    999             gamma='text', tau='text', v='text'))
     1009        if self.regtypes:
     1010            self.assertEqual(r, dict(
     1011                n='integer', alpha='smallint', beta='boolean',
     1012                gamma='character', tau='text', v='character varying'))
     1013        else:
     1014            self.assertEqual(r, dict(
     1015                n='int', alpha='int', beta='bool',
     1016                gamma='text', tau='text', v='text'))
    10001017
    10011018    def testGetAttnamesWithQuotes(self):
     
    10061023        r = get_attnames(table)
    10071024        self.assertIsInstance(r, dict)
    1008         self.assertEqual(r, {
    1009             'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
     1025        if self.regtypes:
     1026            self.assertEqual(r, {
     1027                'Prime!': 'smallint', 'much space': 'integer',
     1028                'Questions?': 'text'})
     1029        else:
     1030            self.assertEqual(r, {
     1031                'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
    10101032        table = 'yet another test table for get_attnames()'
    10111033        self.createTable(table,
    10121034            'a smallint, b integer, c bigint,'
    1013             ' e numeric, f float, f2 double precision, m money,'
     1035            ' e numeric, f real, f2 double precision, m money,'
    10141036            ' x smallint, y smallint, z smallint,'
    10151037            ' Normal_NaMe smallint, "Special Name" smallint,'
     
    10181040        r = get_attnames(table)
    10191041        self.assertIsInstance(r, dict)
    1020         self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
    1021             'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
    1022             'normal_name': 'int', 'Special Name': 'int',
    1023             'u': 'text', 't': 'text', 'v': 'text',
    1024             'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
     1042        if self.regtypes:
     1043            self.assertEqual(r, {
     1044                'a': 'smallint', 'b': 'integer', 'c': 'bigint',
     1045                'e': 'numeric', 'f': 'real', 'f2': 'double precision',
     1046                'm': 'money', 'normal_name': 'smallint',
     1047                'Special Name': 'smallint', 'u': 'character',
     1048                't': 'text', 'v': 'character varying', 'y': 'smallint',
     1049                'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
     1050        else:
     1051            self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
     1052                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
     1053                'normal_name': 'int', 'Special Name': 'int',
     1054                'u': 'text', 't': 'text', 'v': 'text',
     1055                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
    10251056
    10261057    def testGetAttnamesWithRegtypes(self):
     
    10311062        use_regtypes = self.db.use_regtypes
    10321063        regtypes = use_regtypes()
    1033         self.assertFalse(regtypes)
     1064        self.assertEqual(regtypes, self.regtypes)
    10341065        use_regtypes(True)
    10351066        try:
     
    10421073            gamma='character', tau='text', v='character varying'))
    10431074
     1075    def testGetAttnamesWithoutRegtypes(self):
     1076        get_attnames = self.db.get_attnames
     1077        self.createTable('test_table',
     1078            ' n int, alpha smallint, beta bool,'
     1079            ' gamma char(5), tau text, v varchar(3)')
     1080        use_regtypes = self.db.use_regtypes
     1081        regtypes = use_regtypes()
     1082        self.assertEqual(regtypes, self.regtypes)
     1083        use_regtypes(False)
     1084        try:
     1085            r = get_attnames("test_table")
     1086            self.assertIsInstance(r, dict)
     1087        finally:
     1088            use_regtypes(regtypes)
     1089        self.assertEqual(r, dict(
     1090            n='int', alpha='int', beta='bool',
     1091            gamma='text', tau='text', v='text'))
     1092
    10441093    def testGetAttnamesIsCached(self):
    10451094        get_attnames = self.db.get_attnames
     1095        int_type = 'integer' if self.regtypes else 'int'
     1096        text_type = 'text'
    10461097        query = self.db.query
    10471098        self.createTable('test_table', 'col int')
    10481099        r = get_attnames("test_table")
    10491100        self.assertIsInstance(r, dict)
    1050         self.assertEqual(r, dict(col='int'))
     1101        self.assertEqual(r, dict(col=int_type))
    10511102        query("alter table test_table alter column col type text")
    10521103        query("alter table test_table add column col2 int")
    10531104        r = get_attnames("test_table")
    1054         self.assertEqual(r, dict(col='int'))
     1105        self.assertEqual(r, dict(col=int_type))
    10551106        r = get_attnames("test_table", flush=True)
    1056         self.assertEqual(r, dict(col='text', col2='int'))
     1107        self.assertEqual(r, dict(col=text_type, col2=int_type))
    10571108        query("alter table test_table drop column col2")
    10581109        r = get_attnames("test_table")
    1059         self.assertEqual(r, dict(col='text', col2='int'))
     1110        self.assertEqual(r, dict(col=text_type, col2=int_type))
    10601111        r = get_attnames("test_table", flush=True)
    1061         self.assertEqual(r, dict(col='text'))
     1112        self.assertEqual(r, dict(col=text_type))
    10621113        query("alter table test_table drop column col")
    10631114        r = get_attnames("test_table")
    1064         self.assertEqual(r, dict(col='text'))
     1115        self.assertEqual(r, dict(col=text_type))
    10651116        r = get_attnames("test_table", flush=True)
    10661117        self.assertEqual(r, dict())
     
    10701121        r = get_attnames('test', flush=True)
    10711122        self.assertIsInstance(r, OrderedDict)
    1072         self.assertEqual(r, OrderedDict([
    1073             ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
    1074             ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
    1075             ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
     1123        if self.regtypes:
     1124            self.assertEqual(r, OrderedDict([
     1125                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
     1126                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
     1127                ('m', 'money'), ('v4', 'character varying'),
     1128                ('c4', 'character'), ('t', 'text')]))
     1129        else:
     1130            self.assertEqual(r, OrderedDict([
     1131                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
     1132                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
     1133                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
    10761134        if OrderedDict is not dict:
    10771135            r = ' '.join(list(r.keys()))
     
    10831141        r = get_attnames(table)
    10841142        self.assertIsInstance(r, OrderedDict)
    1085         self.assertEqual(r, OrderedDict([
    1086             ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
    1087             ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
     1143        if self.regtypes:
     1144            self.assertEqual(r, OrderedDict([
     1145                ('n', 'integer'), ('alpha', 'smallint'),
     1146                ('v', 'character varying'), ('gamma', 'character'),
     1147                ('tau', 'text'), ('beta', 'boolean')]))
     1148        else:
     1149            self.assertEqual(r, OrderedDict([
     1150                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
     1151                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
    10881152        if OrderedDict is not dict:
    10891153            r = ' '.join(list(r.keys()))
     
    10971161        r = get_attnames('test', flush=True)
    10981162        self.assertIsInstance(r, AttrDict)
    1099         self.assertEqual(r, AttrDict([
    1100             ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
    1101             ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
    1102             ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
     1163        if self.regtypes:
     1164            self.assertEqual(r, AttrDict([
     1165                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
     1166                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
     1167                ('m', 'money'), ('v4', 'character varying'),
     1168                ('c4', 'character'), ('t', 'text')]))
     1169        else:
     1170            self.assertEqual(r, AttrDict([
     1171                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
     1172                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
     1173                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
    11031174        r = ' '.join(list(r.keys()))
    11041175        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
     
    11091180        r = get_attnames(table)
    11101181        self.assertIsInstance(r, AttrDict)
    1111         self.assertEqual(r, AttrDict([
    1112             ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
    1113             ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
     1182        if self.regtypes:
     1183            self.assertEqual(r, AttrDict([
     1184                ('n', 'integer'), ('alpha', 'smallint'),
     1185                ('v', 'character varying'), ('gamma', 'character'),
     1186                ('tau', 'text'), ('beta', 'boolean')]))
     1187        else:
     1188            self.assertEqual(r, AttrDict([
     1189                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
     1190                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
    11141191        r = ' '.join(list(r.keys()))
    11151192        self.assertEqual(r, 'n alpha v gamma tau beta')
     
    29703047            ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
    29713048        r = self.db.get_attnames('arraytest')
    2972         self.assertEqual(r, dict(id='int', i2='int[]', i4='int[]', i8='int[]',
    2973             d='num[]', f4='float[]', f8='float[]', m='money[]',
    2974             b='bool[]', v4='text[]', c4='text[]', t='text[]'))
     3049        if self.regtypes:
     3050            self.assertEqual(r, dict(
     3051                id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
     3052                d='numeric[]', f4='real[]', f8='double precision[]',
     3053                m='money[]', b='boolean[]',
     3054                v4='character varying[]', c4='character[]', t='text[]'))
     3055        else:
     3056            self.assertEqual(r, dict(
     3057                id='int', i2='int[]', i4='int[]', i8='int[]',
     3058                d='num[]', f4='float[]', f8='float[]', m='money[]',
     3059                b='bool[]', v4='text[]', c4='text[]', t='text[]'))
    29753060        decimal = pg.get_decimal()
    29763061        if decimal is Decimal:
     
    30053090        self.assertEqual(r, data)
    30063091
    3007     def testArrayInput(self):
     3092    def testArrayLiteral(self):
    30083093        insert = self.db.insert
    30093094        self.createTable('arraytest', 'i int[], t text[]', oids=True)
     
    30163101        self.assertEqual(r['i'], [1, 2, 3])
    30173102        self.assertEqual(r['t'], ['a', 'b', 'c'])
    3018         r = dict(i="[1, 2, 3]", t="['a', 'b', 'c']")
    3019         self.db.insert('arraytest', r)
    3020         self.assertEqual(r['i'], [1, 2, 3])
    3021         self.assertEqual(r['t'], ['a', 'b', 'c'])
    3022         r = dict(i="array[1, 2, 3]", t="array['a', 'b', 'c']")
    3023         self.db.insert('arraytest', r)
    3024         self.assertEqual(r['i'], [1, 2, 3])
    3025         self.assertEqual(r['t'], ['a', 'b', 'c'])
    3026         r = dict(i="ARRAY[1, 2, 3]", t="ARRAY['a', 'b', 'c']")
     3103        L = pg._Literal
     3104        r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
    30273105        self.db.insert('arraytest', r)
    30283106        self.assertEqual(r['i'], [1, 2, 3])
    30293107        self.assertEqual(r['t'], ['a', 'b', 'c'])
    30303108        r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
    3031         self.assertRaises(ValueError, self.db.insert, 'arraytest', r)
     3109        self.assertRaises(pg.ProgrammingError, self.db.insert, 'arraytest', r)
    30323110
    30333111    def testArrayOfIds(self):
    30343112        self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
    30353113        r = self.db.get_attnames('arraytest')
    3036         self.assertEqual(r, dict(oid='int', c='int[]', o='int[]', x='int[]'))
     3114        if self.regtypes:
     3115            self.assertEqual(r, dict(
     3116                oid='oid', c='cid[]', o='oid[]', x='xid[]'))
     3117        else:
     3118            self.assertEqual(r, dict(
     3119                oid='int', c='int[]', o='int[]', x='int[]'))
    30373120        data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
    30383121        r = data.copy()
     
    31263209            self.fail(str(error))
    31273210        r = self.db.get_attnames('arraytest')
    3128         self.assertEqual(r['data'], 'json[]')
     3211        self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
    31293212        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
    31303213        jsondecode = pg.get_jsondecode()
     
    31663249        self.db.get('arraytest', r)
    31673250        self.assertEqual(r['data'], data)
     3251
     3252    def testInsertUpdateGetRecord(self):
     3253        query = self.db.query
     3254        query('create type test_person_type as'
     3255            ' (name varchar, age smallint, married bool,'
     3256              ' weight real, salary money)')
     3257        self.addCleanup(query, 'drop type test_person_type')
     3258        self.createTable('test_person', 'person test_person_type',
     3259            temporary=False, oids=True)
     3260        attnames = self.db.get_attnames('test_person')
     3261        self.assertEqual(len(attnames), 2)
     3262        self.assertIn('oid', attnames)
     3263        self.assertIn('person', attnames)
     3264        person_typ = attnames['person']
     3265        if self.regtypes:
     3266            self.assertEqual(person_typ, 'test_person_type')
     3267        else:
     3268            self.assertEqual(person_typ, 'record')
     3269        if self.regtypes:
     3270            self.assertEqual(person_typ.attnames,
     3271                dict(name='character varying', age='smallint',
     3272                    married='boolean', weight='real', salary='money'))
     3273        else:
     3274            self.assertEqual(person_typ.attnames,
     3275                dict(name='text', age='int', married='bool',
     3276                    weight='float', salary='money'))
     3277        decimal = pg.get_decimal()
     3278        if pg.get_bool():
     3279            bool_class = bool
     3280            t, f = True, False
     3281        else:
     3282            bool_class = str
     3283            t, f = 't', 'f'
     3284        person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
     3285        r = self.db.insert('test_person', None, person=person)
     3286        p = r['person']
     3287        self.assertIsInstance(p, tuple)
     3288        self.assertEqual(p, person)
     3289        self.assertEqual(p.name, 'John Doe')
     3290        self.assertIsInstance(p.name, str)
     3291        self.assertIsInstance(p.age, int)
     3292        self.assertIsInstance(p.married, bool_class)
     3293        self.assertIsInstance(p.weight, float)
     3294        self.assertIsInstance(p.salary, decimal)
     3295        person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
     3296        r['person'] = person
     3297        self.db.update('test_person', r)
     3298        p = r['person']
     3299        self.assertIsInstance(p, tuple)
     3300        self.assertEqual(p, person)
     3301        self.assertEqual(p.name, 'Jane Roe')
     3302        self.assertIsInstance(p.name, str)
     3303        self.assertIsInstance(p.age, int)
     3304        self.assertIsInstance(p.married, bool_class)
     3305        self.assertIsInstance(p.weight, float)
     3306        self.assertIsInstance(p.salary, decimal)
     3307        r['person'] = None
     3308        self.db.get('test_person', r)
     3309        p = r['person']
     3310        self.assertIsInstance(p, tuple)
     3311        self.assertEqual(p, person)
     3312        self.assertEqual(p.name, 'Jane Roe')
     3313        self.assertIsInstance(p.name, str)
     3314        self.assertIsInstance(p.age, int)
     3315        self.assertIsInstance(p.married, bool_class)
     3316        self.assertIsInstance(p.weight, float)
     3317        self.assertIsInstance(p.salary, decimal)
     3318        person = (None,) * 5
     3319        r = self.db.insert('test_person', None, person=person)
     3320        p = r['person']
     3321        self.assertIsInstance(p, tuple)
     3322        self.assertIsNone(p.name)
     3323        self.assertIsNone(p.age)
     3324        self.assertIsNone(p.married)
     3325        self.assertIsNone(p.weight)
     3326        self.assertIsNone(p.salary)
     3327        r['person'] = None
     3328        self.db.get('test_person', r)
     3329        p = r['person']
     3330        self.assertIsInstance(p, tuple)
     3331        self.assertIsNone(p.name)
     3332        self.assertIsNone(p.age)
     3333        self.assertIsNone(p.married)
     3334        self.assertIsNone(p.weight)
     3335        self.assertIsNone(p.salary)
     3336        r = self.db.insert('test_person', None, person=None)
     3337        self.assertIsNone(r['person'])
     3338        r['person'] = None
     3339        self.db.get('test_person', r)
     3340        self.assertIsNone(r['person'])
     3341
     3342    def testRecordInsertBytea(self):
     3343        query = self.db.query
     3344        query('create type test_person_type as'
     3345            ' (name text, picture bytea)')
     3346        self.addCleanup(query, 'drop type test_person_type')
     3347        self.createTable('test_person', 'person test_person_type',
     3348            temporary=False, oids=True)
     3349        person_typ = self.db.get_attnames('test_person')['person']
     3350        self.assertEqual(person_typ.attnames,
     3351            dict(name='text', picture='bytea'))
     3352        person = ('John Doe', b'O\x00ps\xff!')
     3353        r = self.db.insert('test_person', None, person=person)
     3354        p = r['person']
     3355        self.assertIsInstance(p, tuple)
     3356        self.assertEqual(p, person)
     3357        self.assertEqual(p.name, 'John Doe')
     3358        self.assertIsInstance(p.name, str)
     3359        self.assertEqual(p.picture, person[1])
     3360        self.assertIsInstance(p.picture, bytes)
     3361
     3362    def testRecordInsertJson(self):
     3363        query = self.db.query
     3364        try:
     3365            query('create type test_person_type as'
     3366                ' (name text, data json)')
     3367        except pg.ProgrammingError as error:
     3368            if self.db.server_version < 90200:
     3369                self.skipTest('database does not support json')
     3370            self.fail(str(error))
     3371        self.addCleanup(query, 'drop type test_person_type')
     3372        self.createTable('test_person', 'person test_person_type',
     3373            temporary=False, oids=True)
     3374        person_typ = self.db.get_attnames('test_person')['person']
     3375        self.assertEqual(person_typ.attnames,
     3376            dict(name='text', data='json'))
     3377        person = ('John Doe', dict(age=61, married=True, weight=99.5))
     3378        r = self.db.insert('test_person', None, person=person)
     3379        p = r['person']
     3380        self.assertIsInstance(p, tuple)
     3381        if pg.get_jsondecode() is None:
     3382            p = p._replace(data=json.loads(p.data))
     3383        self.assertEqual(p, person)
     3384        self.assertEqual(p.name, 'John Doe')
     3385        self.assertIsInstance(p.name, str)
     3386        self.assertEqual(p.data, person[1])
     3387        self.assertIsInstance(p.data, dict)
     3388
     3389    def testRecordLiteral(self):
     3390        query = self.db.query
     3391        query('create type test_person_type as'
     3392            ' (name varchar, age smallint)')
     3393        self.addCleanup(query, 'drop type test_person_type')
     3394        self.createTable('test_person', 'person test_person_type',
     3395            temporary=False, oids=True)
     3396        person_typ = self.db.get_attnames('test_person')['person']
     3397        if self.regtypes:
     3398            self.assertEqual(person_typ, 'test_person_type')
     3399        else:
     3400            self.assertEqual(person_typ, 'record')
     3401        if self.regtypes:
     3402            self.assertEqual(person_typ.attnames,
     3403                dict(name='character varying', age='smallint'))
     3404        else:
     3405            self.assertEqual(person_typ.attnames,
     3406                dict(name='text', age='int'))
     3407        person = pg._Literal("('John Doe', 61)")
     3408        r = self.db.insert('test_person', None, person=person)
     3409        p = r['person']
     3410        self.assertIsInstance(p, tuple)
     3411        self.assertEqual(p.name, 'John Doe')
     3412        self.assertIsInstance(p.name, str)
     3413        self.assertEqual(p.age, 61)
     3414        self.assertIsInstance(p.age, int)
    31683415
    31693416    def testNotificationHandler(self):
     
    32563503        cls.set_option('namedresult', None)
    32573504        cls.set_option('jsondecode', None)
     3505        cls.regtypes = not DB().use_regtypes()
    32583506        super(TestDBClassNonStdOpts, cls).setUpClass()
    32593507
  • trunk/tests/test_classic_functions.py

    r791 r793  
    217217        self.assertRaises(TypeError, f, None)
    218218        self.assertRaises(TypeError, f, '{}', 1)
    219         self.assertRaises(TypeError, f, '{}', ',',)
     219        self.assertRaises(TypeError, f, '{}', b',',)
    220220        self.assertRaises(TypeError, f, '{}', None, None)
    221221        self.assertRaises(TypeError, f, '{}', None, 1)
    222         self.assertRaises(TypeError, f, '{}', None, '')
    223         self.assertRaises(ValueError, f, '{}', None, '\\')
    224         self.assertRaises(ValueError, f, '{}', None, '{')
    225         self.assertRaises(ValueError, f, '{}', None, '}')
    226         self.assertRaises(TypeError, f, '{}', None, ',;')
     222        self.assertRaises(TypeError, f, '{}', None, b'')
     223        self.assertRaises(ValueError, f, '{}', None, b'\\')
     224        self.assertRaises(ValueError, f, '{}', None, b'{')
     225        self.assertRaises(ValueError, f, '{}', None, b'}')
     226        self.assertRaises(TypeError, f, '{}', None, b',;')
    227227        self.assertEqual(f('{}'), [])
    228228        self.assertEqual(f('{}', None), [])
     
    489489        self.assertRaises(TypeError, f, None)
    490490        self.assertRaises(TypeError, f, '()', 1)
    491         self.assertRaises(TypeError, f, '()', ',',)
     491        self.assertRaises(TypeError, f, '()', b',',)
    492492        self.assertRaises(TypeError, f, '()', None, None)
    493493        self.assertRaises(TypeError, f, '()', None, 1)
    494         self.assertRaises(TypeError, f, '()', None, '')
    495         self.assertRaises(ValueError, f, '()', None, '\\')
    496         self.assertRaises(ValueError, f, '()', None, '(')
    497         self.assertRaises(ValueError, f, '()', None, ')')
    498         self.assertRaises(TypeError, f, '{}', None, ',;')
     494        self.assertRaises(TypeError, f, '()', None, b'')
     495        self.assertRaises(ValueError, f, '()', None, b'\\')
     496        self.assertRaises(ValueError, f, '()', None, b'(')
     497        self.assertRaises(ValueError, f, '()', None, b')')
     498        self.assertRaises(TypeError, f, '{}', None, b',;')
    499499        self.assertEqual(f('()'), (None,))
    500500        self.assertEqual(f('()', None), (None,))
Note: See TracChangeset for help on using the changeset viewer.