Changeset 814 for trunk/pg.py


Ignore:
Timestamp:
Feb 3, 2016, 3:23:20 PM (3 years ago)
Author:
cito
Message:

Add typecasting of dates, times, timestamps, intervals

So far, PyGreSQL has returned these types only as strings (in various
formats depending on the DateStyle setting) and left it to the user
to parse and interpret the strings. These types are now properly cast
into the corresponding datetime types of Python, and this works with
any setting of DateStyle, even if you change DateStyle in the middle
of a database session.

To implement this, a fast method for getting the datestyle (cached and
without roundtrip to the database) has been added. Also, the typecast
mechanism has been extended so that typecast functions can optionally
also take the connection as argument.

The date and time typecast functions have been implemented in Python
using the new typecast registry and added to both pg and pgdb. Some
duplication of code in the two modules was unavoidable, since we don't
want the modules to be dependent on each other or to install additional
helper modules. One day we might want to change this, put everything
in one package and factor out some of the functionality.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/pg.py

    r799 r814  
    5252except NameError:  # Python >= 3.0
    5353    basestring = (str, bytes)
     54
     55
     56# Auxiliary classes and functions that are independent from a DB connection:
    5457
    5558try:
     
    138141            raise TypeError('This object is read-only')
    139142
    140 
    141 # Auxiliary classes and functions that are independent from a DB connection:
     143try:
     144    from inspect import signature
     145except ImportError:  # Python < 3.3
     146    from inspect import getargspec
     147
     148    get_args = lambda func: getargspec(func).args
     149else:
     150    get_args = lambda func: list(signature(func).parameters)
     151
     152try:
     153    if datetime.strptime('+0100', '%z') is None:
     154        raise ValueError
     155except ValueError:  # Python < 3.2
     156    timezones = None
     157else:
     158    # time zones used in Postgres timestamptz output
     159    timezones = dict(CET='+0100', EET='+0200', EST='-0500',
     160        GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
     161        UCT='+0000', UTC='+0000', WET='+0000')
     162
    142163
    143164def _oid_key(table):
     
    588609
    589610
     611def cast_date(value, connection):
     612    """Cast a date value."""
     613    # The output format depends on the server setting DateStyle.  The default
     614    # setting ISO and the setting for German are actually unambiguous.  The
     615    # order of days and months in the other two settings is however ambiguous,
     616    # so at least here we need to consult the setting to properly parse values.
     617    if value == '-infinity':
     618        return date.min
     619    if value == 'infinity':
     620        return date.max
     621    value = value.split()
     622    if value[-1] == 'BC':
     623        return date.min
     624    value = value[0]
     625    if len(value) > 10:
     626        return date.max
     627    fmt = connection.date_format()
     628    return datetime.strptime(value, fmt).date()
     629
     630
     631def cast_time(value):
     632    """Cast a time value."""
     633    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
     634    return datetime.strptime(value, fmt).time()
     635
     636
     637_re_timezone = regex('(.*)([+-].*)')
     638
     639
     640def cast_timetz(value):
     641    """Cast a timetz value."""
     642    tz = _re_timezone.match(value)
     643    if tz:
     644        value, tz = tz.groups()
     645    else:
     646        tz = '+0000'
     647    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
     648    if timezones:
     649        if tz.startswith(('+', '-')):
     650            if len(tz) < 5:
     651                tz += '00'
     652            else:
     653                tz = tz.replace(':', '')
     654        elif tz in timezones:
     655            tz = timezones[tz]
     656        else:
     657            tz = '+0000'
     658        value += tz
     659        fmt += '%z'
     660    return datetime.strptime(value, fmt).timetz()
     661
     662
     663def cast_timestamp(value, connection):
     664    """Cast a timestamp value."""
     665    if value == '-infinity':
     666        return datetime.min
     667    if value == 'infinity':
     668        return datetime.max
     669    value = value.split()
     670    if value[-1] == 'BC':
     671        return datetime.min
     672    fmt = connection.date_format()
     673    if fmt.endswith('-%Y') and len(value) > 2:
     674        value = value[1:5]
     675        if len(value[3]) > 4:
     676            return datetime.max
     677        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
     678            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
     679    else:
     680        if len(value[0]) > 10:
     681            return datetime.max
     682        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
     683    return datetime.strptime(' '.join(value), ' '.join(fmt))
     684
     685
     686def cast_timestamptz(value, connection):
     687    """Cast a timestamptz value."""
     688    if value == '-infinity':
     689        return datetime.min
     690    if value == 'infinity':
     691        return datetime.max
     692    value = value.split()
     693    if value[-1] == 'BC':
     694        return datetime.min
     695    fmt = connection.date_format()
     696    if fmt.endswith('-%Y') and len(value) > 2:
     697        value = value[1:]
     698        if len(value[3]) > 4:
     699            return datetime.max
     700        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
     701            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
     702        value, tz = value[:-1], value[-1]
     703    else:
     704        if fmt.startswith('%Y-'):
     705            tz = _re_timezone.match(value[1])
     706            if tz:
     707                value[1], tz = tz.groups()
     708            else:
     709                tz = '+0000'
     710        else:
     711            value, tz = value[:-1], value[-1]
     712        if len(value[0]) > 10:
     713            return datetime.max
     714        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
     715    if timezones:
     716        if tz.startswith(('+', '-')):
     717            if len(tz) < 5:
     718                tz += '00'
     719            else:
     720                tz = tz.replace(':', '')
     721        elif tz in timezones:
     722            tz = timezones[tz]
     723        else:
     724            tz = '+0000'
     725        value.append(tz)
     726        fmt.append('%z')
     727    return datetime.strptime(' '.join(value), ' '.join(fmt))
     728
     729_re_interval_sql_standard = regex(
     730    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
     731    '(?:([+-]?[0-9]+)(?!:) ?)?'
     732    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
     733
     734_re_interval_postgres = regex(
     735    '(?:([+-]?[0-9]+) ?years? ?)?'
     736    '(?:([+-]?[0-9]+) ?mons? ?)?'
     737    '(?:([+-]?[0-9]+) ?days? ?)?'
     738    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
     739
     740_re_interval_postgres_verbose = regex(
     741    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
     742    '(?:([+-]?[0-9]+) ?mons? ?)?'
     743    '(?:([+-]?[0-9]+) ?days? ?)?'
     744    '(?:([+-]?[0-9]+) ?hours? ?)?'
     745    '(?:([+-]?[0-9]+) ?mins? ?)?'
     746    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')
     747
     748_re_interval_iso_8601 = regex(
     749    'P(?:([+-]?[0-9]+)Y)?'
     750    '(?:([+-]?[0-9]+)M)?'
     751    '(?:([+-]?[0-9]+)D)?'
     752    '(?:T(?:([+-]?[0-9]+)H)?'
     753    '(?:([+-]?[0-9]+)M)?'
     754    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')
     755
     756
     757def cast_interval(value):
     758    """Cast an interval value."""
     759    # The output format depends on the server setting IntervalStyle, but it's
     760    # not necessary to consult this setting to parse it.  It's faster to just
     761    # check all possible formats, and there is no ambiguity here.
     762    m = _re_interval_iso_8601.match(value)
     763    if m:
     764        m = [d or '0' for d in m.groups()]
     765        secs_ago = m.pop(5) == '-'
     766        m = [int(d) for d in m]
     767        years, mons, days, hours, mins, secs, usecs = m
     768        if secs_ago:
     769            secs = -secs
     770            usecs = -usecs
     771    else:
     772        m = _re_interval_postgres_verbose.match(value)
     773        if m:
     774            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
     775            secs_ago = m.pop(5) == '-'
     776            m = [-int(d) for d in m] if ago else [int(d) for d in m]
     777            years, mons, days, hours, mins, secs, usecs = m
     778            if secs_ago:
     779                secs = - secs
     780                usecs = -usecs
     781        else:
     782            m = _re_interval_postgres.match(value)
     783            if m and any(m.groups()):
     784                m = [d or '0' for d in m.groups()]
     785                hours_ago = m.pop(3) == '-'
     786                m = [int(d) for d in m]
     787                years, mons, days, hours, mins, secs, usecs = m
     788                if hours_ago:
     789                    hours = -hours
     790                    mins = -mins
     791                    secs = -secs
     792                    usecs = -usecs
     793            else:
     794                m = _re_interval_sql_standard.match(value)
     795                if m and any(m.groups()):
     796                    m = [d or '0' for d in m.groups()]
     797                    years_ago = m.pop(0) == '-'
     798                    hours_ago = m.pop(3) == '-'
     799                    m = [int(d) for d in m]
     800                    years, mons, days, hours, mins, secs, usecs = m
     801                    if years_ago:
     802                        years = -years
     803                        mons = -mons
     804                    if hours_ago:
     805                        hours = -hours
     806                        mins = -mins
     807                        secs = -secs
     808                        usecs = -usecs
     809                else:
     810                    raise ValueError('Cannot parse interval: %s' % value)
     811    days += 365 * years + 30 * mons
     812    return timedelta(days=days, hours=hours, minutes=mins,
     813        seconds=secs, microseconds=usecs)
     814
     815
    590816class Typecasts(dict):
    591817    """Dictionary mapping database types to typecast functions.
     
    610836        'float4': float, 'float8': float,
    611837        'numeric': cast_num, 'money': cast_money,
     838        'date': cast_date, 'interval': cast_interval,
     839        'time': cast_time, 'timetz': cast_timetz,
     840        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
    612841        'int2vector': cast_int2vector,
    613842        'anyarray': cast_array, 'record': cast_record}
     843
     844    connection = None  # will be set in a connection specific instance
    614845
    615846    def __missing__(self, typ):
     
    624855        if cast:
    625856            # store default for faster access
     857            cast = self._add_connection(cast)
    626858            self[typ] = cast
    627859        elif typ.startswith('_'):
     
    637869                self[typ] = cast
    638870        return cast
     871
    @staticmethod
    def _needs_connection(func):
        """Tell whether a typecast function takes a connection argument.

        This is the case if any parameter after the first one is named
        'connection'.  Functions whose arguments cannot be determined are
        reported as not taking a connection.
        """
        try:
            params = get_args(func)
        except (TypeError, ValueError):
            # the arguments could not be inspected
            return False
        return 'connection' in params[1:]
     881
    def _add_connection(self, cast):
        """Bind the connection to the typecast function if it takes one.

        The function is returned unchanged if no connection has been set
        or if it does not take a connection argument.  Otherwise a wrapper
        taking only the value is returned.
        """
        connection = self.connection
        if connection and self._needs_connection(cast):

            def cast_with_connection(value):
                return cast(value, connection=connection)

            return cast_with_connection
        return cast
    639888
    640889    def get(self, typ, default=None):
     
    654903                raise TypeError("Cast parameter must be callable")
    655904            for t in typ:
    656                 self[t] = cast
     905                self[t] = self._add_connection(cast)
    657906                self.pop('_%s' % t, None)
    658907
     
    699948        return {}
    700949
    def dateformat(self):
        """Return the current date format.

        This default implementation always returns the format '%Y-%m-%d'.

        This method will be replaced with the dateformat() method of DbTypes.
        """
        return '%Y-%m-%d'
     956
    701957    def create_array_cast(self, cast):
    702958        """Create an array typecast for the given base cast."""
     
    7581014        """Initialize type cache for connection."""
    7591015        super(DbTypes, self).__init__()
     1016        self._regtypes = False
    7601017        self._get_attnames = db.get_attnames
     1018        self._typecasts = Typecasts()
     1019        self._typecasts.get_attnames = self.get_attnames
     1020        self._typecasts.connection = db
    7611021        db = db.db
    7621022        self.query = db.query
    7631023        self.escape_string = db.escape_string
    764         self._typecasts = Typecasts()
    765         self._typecasts.get_attnames = self.get_attnames
    766         self._regtypes = False
    7671024
    7681025    def add(self, oid, pgtype, regtype,
Note: See TracChangeset for help on using the changeset viewer.