source: trunk/pgdb.py @ 814

Last change on this file since 814 was 814, checked in by cito, 4 years ago

Add typecasting of dates, times, timestamps, intervals

So far, PyGreSQL has returned these types only as strings (in various
formats depending on the DateStyle setting) and left it to the user
to parse and interpret the strings. These types are now properly cast
into the corresponding datetime types of Python, and this works with
any setting of DateStyle, even if you change DateStyle in the middle
of a database session.

To implement this, a fast method for getting the datestyle (cached and
without roundtrip to the database) has been added. Also, the typecast
mechanism has been extended so that typecast functions can optionally
also take the connection as argument.

The date and time typecast functions have been implemented in Python
using the new typecast registry and added to both pg and pgdb. Some
duplication of code in the two modules was unavoidable, since we don't
want the modules to be dependent on each other or to install additional
helper modules. One day we might want to change this, put everything
in one package and factor out some of the functionality.

  • Property svn:keywords set to Id
File size: 52.3 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 814 2016-02-03 20:23:20Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary) if they are
32    # passed. The binding syntax is the same as the % operator
33    # for dictionaries, and no quoting is done.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from re import compile as regex
76from json import loads as jsondecode, dumps as jsonencode
77
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

try:
    # The abstract base classes live in collections.abc since Python 3.3;
    # the old aliases in the collections module were removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
94
95
96### Module Constants
97
# Module globals required by the DB-API 2.0 specification (PEP 249).
apilevel = "2.0"  # the module complies with DB-API 2.0
threadsafety = 1  # threads may share the module, but not connections
paramstyle = "pyformat"  # parameters use extended Python format codes

# Shortcut methods have been excluded from DB-API 2 and are not
# recommended by the DB SIG, but they can be handy.
shortcutmethods = 1
110
111
112### Internal Type Handling
113
try:
    from inspect import signature
except ImportError:  # Python < 3.3 only has the legacy argspec API
    from inspect import getargspec

    def get_args(func):
        """Return the names of the positional arguments of func."""
        return getargspec(func).args
else:
    def get_args(func):
        """Return the names of all parameters of func, in order."""
        return [name for name in signature(func).parameters]
122
try:
    # the %z directive of strptime is only supported since Python 3.2
    if datetime.strptime('+0100', '%z') is None:
        raise ValueError
except ValueError:  # Python < 3.2
    timezones = None
else:
    # Time zone abbreviations used in Postgres timestamptz output, mapped
    # to their UTC offsets in the +HHMM form understood by %z.  Besides the
    # standard time names, the summer time abbreviations of the European
    # zones CET, EET, MET and WET are included as well, since these are
    # shown instead of the standard names during daylight saving time.
    timezones = dict(
        CET='+0100', CEST='+0200',
        EET='+0200', EEST='+0300',
        EST='-0500', GMT='+0000', HST='-1000',
        MET='+0100', MEST='+0200',
        MST='-0700', UCT='+0000', UTC='+0000',
        WET='+0000', WEST='+0100')
133
134
def decimal_type(decimal_type=None):
    """Return the type used for decimal values, optionally replacing it.

    When decimal_type is given, it becomes the module-wide Decimal type
    and is registered as the global typecast for the numeric type.

    Connections keep their own cache of cast functions, so a running
    connection may not notice the change until you call
    con.type_cache.reset_typecast().
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    Decimal = decimal_type
    set_typecast('numeric', decimal_type)
    return Decimal
146
147
def cast_bool(value):
    """Convert a database boolean string to a Python bool.

    Only the first character matters: 't' or 'T' means true.
    An empty value yields None.
    """
    if not value:
        return None
    return value[:1] in ('t', 'T')
152
153
def cast_money(value):
    """Convert a database money string to a Decimal.

    An opening parenthesis marks a negative amount.  Everything except
    digits, the point and the minus sign (currency symbols, grouping
    separators, the closing parenthesis) is discarded.  An empty value
    yields None.
    """
    if not value:
        return None
    signed = value.replace('(', '-')
    kept = filter(lambda c: c.isdigit() or c in '.-', signed)
    return Decimal(''.join(kept))
159
160
def cast_int2vector(value):
    """Convert a whitespace-separated int2vector string to a list of ints."""
    return list(map(int, value.split()))
164
165
def cast_date(value, connection):
    """Convert a database date string to a datetime.date.

    The output format depends on the server setting DateStyle.  ISO and
    German output are unambiguous, but the order of day and month is not
    in the other styles, so the format is taken from the connection's
    date_format() method.  Infinite and BC dates, as well as dates longer
    than ten characters (years with more than four digits), cannot be
    represented and are mapped to date.min or date.max.
    """
    if value == '-infinity':
        return date.min
    if value == 'infinity':
        return date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    text = parts[0]
    if len(text) > 10:
        return date.max
    return datetime.strptime(text, connection.date_format()).date()
184
185
def cast_time(value):
    """Convert a database time string (HH:MM:SS[.ffffff]) to a datetime.time."""
    if len(value) > 8:  # there is a fractional seconds part
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    return datetime.strptime(value, fmt).time()
190
191
# splits a value into the part before and the part starting with the
# last '+' or '-' sign (the UTC offset)
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Convert a database timetz string to a datetime.time with tzinfo.

    The trailing UTC offset is normalized to the +HHMM form expected by
    the %z directive of strptime; a value without offset counts as UTC.
    If %z is not supported (the module-level timezones is None), the
    offset is ignored and a naive time is returned.
    """
    found = _re_timezone.match(value)
    if found:
        value, tz = found.groups()
    else:
        tz = '+0000'
    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
    if not timezones:
        return datetime.strptime(value, fmt).timetz()
    if tz.startswith(('+', '-')):
        # '+01' becomes '+0100', '+05:30' becomes '+0530'
        tz = tz + '00' if len(tz) < 5 else tz.replace(':', '')
    else:
        # a zone abbreviation; unknown ones are taken as UTC
        tz = timezones.get(tz, '+0000')
    return datetime.strptime(value + tz, fmt + '%z').timetz()
216
217
def cast_timestamp(value, connection):
    """Convert a database timestamp string to a naive datetime.datetime.

    The layout depends on the DateStyle setting, whose date format is
    taken from the connection's date_format() method.  If that format ends
    with '-%Y' (Postgres style), the value is a weekday followed by day and
    month, time and year; otherwise it is a date part and a time part.
    Infinite and BC timestamps, as well as years with more than four
    digits, cannot be represented and are mapped to datetime.min or max.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        parts = parts[1:5]  # skip the weekday
        if len(parts[3]) > 4:
            return datetime.max
        if date_fmt.startswith('%d'):
            day_month = '%d %b'
        else:
            day_month = '%b %d'
        clock = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        fmt = ' '.join((day_month, clock, '%Y'))
    else:
        if len(parts[0]) > 10:
            return datetime.max
        clock = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        fmt = ' '.join((date_fmt, clock))
    return datetime.strptime(' '.join(parts), fmt)
239
240
def cast_timestamptz(value, connection):
    """Convert a database timestamptz string to a datetime.datetime.

    The layout depends on the DateStyle setting, whose date format is
    taken from the connection's date_format() method:

    * format starting with '%Y-' (ISO style): date and time, with the UTC
      offset attached to the time
    * format ending with '-%Y' (Postgres style): weekday, day and month,
      time, year and time zone as separate words
    * other formats: date, time and time zone as separate words

    The time zone is normalized to the +HHMM form of the %z directive;
    abbreviations are looked up in the module-level timezones mapping, and
    unknown ones are taken as UTC.  If %z is not supported (timezones is
    None), the time zone is ignored and a naive datetime is returned.
    Infinite and BC timestamps, as well as years with more than four
    digits, are mapped to datetime.min or datetime.max.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    words = value.split()
    if words[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(words) > 2:
        words = words[1:]  # skip the weekday
        if len(words[3]) > 4:
            return datetime.max
        day_month = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        clock = '%H:%M:%S.%f' if len(words[2]) > 8 else '%H:%M:%S'
        fmt = [day_month, clock, '%Y']
        words, tz = words[:-1], words[-1]
    else:
        if not date_fmt.startswith('%Y-'):
            # the time zone is the last word
            words, tz = words[:-1], words[-1]
        else:
            # the UTC offset is attached to the time part
            found = _re_timezone.match(words[1])
            if found:
                words[1], tz = found.groups()
            else:
                tz = '+0000'
        if len(words[0]) > 10:
            return datetime.max
        clock = '%H:%M:%S.%f' if len(words[1]) > 8 else '%H:%M:%S'
        fmt = [date_fmt, clock]
    if timezones:
        if tz.startswith(('+', '-')):
            # '+01' becomes '+0100', '+05:30' becomes '+0530'
            tz = tz + '00' if len(tz) < 5 else tz.replace(':', '')
        else:
            tz = timezones.get(tz, '+0000')
        words.append(tz)
        fmt.append('%z')
    return datetime.strptime(' '.join(words), ' '.join(fmt))
283
# Patterns for the interval output formats selected by the server setting
# IntervalStyle.  All parts are optional, so a match may be empty.

# sql_standard, e.g. '1-2', '3 4:05:06' or '-1-2 +3 -4:05:06'
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

# postgres, e.g. '1 year 2 mons 3 days 04:05:06'
_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

# postgres_verbose, e.g. '@ 1 year 2 mons 3 days 4 hours 5 mins 6 secs ago'
_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

# iso_8601, e.g. 'P1Y2M3DT4H5M6S'
_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_numbers(fields):
    """Convert the matched interval field strings to ints.

    All fields are plain integers except the last one, which holds the
    digits after the decimal point of the seconds.  These may come without
    trailing zeros ('4.5' meaning 4.5 seconds), so they are right-padded
    to six places to get the number of microseconds.
    """
    numbers = [int(field) for field in fields[:-1]]
    numbers.append(int(fields[-1][:6].ljust(6, '0')))
    return numbers


def cast_interval(value):
    """Cast an interval value to a datetime.timedelta.

    Since timedelta has no notion of years and months, a year is counted
    as 365 days and a month as 30 days.  Raises ValueError if the value
    cannot be parsed.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        secs_ago = m.pop(5) == '-'
        m = _interval_numbers(m)
        years, mons, days, hours, mins, secs, usecs = m
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            secs_ago = m.pop(5) == '-'
            m = _interval_numbers(m)
            if ago:
                m = [-d for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                hours_ago = m.pop(3) == '-'
                m = _interval_numbers(m)
                years, mons, days, hours, mins, secs, usecs = m
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = [d or '0' for d in m.groups()]
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    m = _interval_numbers(m)
                    years, mons, days, hours, mins, secs, usecs = m
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
369
370
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.

    Cast functions having a parameter named 'connection' (other than the
    first parameter) get the connection stored in the connection attribute
    bound to that parameter, provided that attribute is set.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int,
        'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
        'oid': long, 'oid8': long,
        'float4': float, 'float8': float,
        'numeric': Decimal, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in local connection specific instances

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.
        Array types (names starting with an underscore) get a cast function
        built from the cast function of their element type; it is cached
        only if the element type has a cast function.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # create array cast
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                # store only if base type exists
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument.

        This is the case if it has a parameter named 'connection' other
        than the first one.  Functions whose parameters cannot be inspected
        are assumed not to need a connection.
        """
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        connection = self.connection
        return lambda value: cast(value, connection=connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type."""
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        If cast is None, custom typecasts are removed so that the defaults
        will be used again.  Cached array typecasts for the given types are
        removed as well, so that they will be rebuilt from the new cast.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.

        The entries are removed rather than overwritten with the defaults,
        so that __missing__ recreates them on the next access.  This makes
        sure connection-dependent casts get the connection bound to them,
        and array casts are rebuilt from the reset casts of their element
        types.  (Copying the raw defaults would leave casts such as
        cast_date without their connection argument.)
        """
        if typ is None:
            self.clear()
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)

    def create_array_cast(self, cast):
        """Create an array typecast for the given base cast."""
        return lambda v: cast_array(v, cast)

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts.

        Field names that cannot be used as namedtuple fields (keywords,
        duplicates, names that are no valid identifiers) are replaced
        with positional names such as '_1' instead of raising an error.
        """
        record = namedtuple(name, fields, rename=True)
        return lambda v: record(*cast_record(v, casts))
490
491
# The global typecast registry, which also serves as the source of
# defaults for the connection specific registries.
_typecasts = Typecasts()


def get_typecast(typ):
    """Return the global typecast function for the given database type."""
    return _typecasts.get(typ)


def set_typecast(typ, cast):
    """Register cast as the global typecast for the given database type(s).

    Passing None as cast removes custom typecasts so that the defaults
    will be used again.

    Connections keep their own cache of cast functions, so a running
    connection may not notice this change until you call
    con.type_cache.reset_typecast().
    """
    _typecasts.set(typ, cast)


def reset_typecast(typ=None):
    """Restore the default global typecasts for the given type(s).

    Without a type, all global typecasts are restored.

    Connections keep their own cache of cast functions, so a running
    connection may not notice this change until you call
    con.type_cache.reset_typecast().
    """
    _typecasts.reset(typ)
518
519
class LocalTypecasts(Typecasts):
    """Map typecasts, including local composite types, to cast functions.

    Types that are not cached yet are looked up in the global typecasts,
    which act as defaults here.  Types unknown there are tried as
    composite types, using the fields returned by get_fields().
    """

    defaults = _typecasts

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached."""
        if typ.startswith('_'):
            # array type: derive the cast from the element type's cast
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                self[typ] = cast
            return cast
        cast = self.defaults.get(typ)
        if cast:
            cast = self[typ] = self._add_connection(cast)
            return cast
        # maybe a composite type: build a record cast from its fields
        fields = self.get_fields(typ)
        if fields:
            casts = [self[field.type] for field in fields]
            names = [field.name for field in fields]
            cast = self[typ] = self.create_record_cast(typ, names, casts)
        return cast

    def get_fields(self, typ):
        """Return the fields for the given record type.

        This placeholder returns no fields; a connection specific instance
        replaces it with a method that looks up the fields using the type
        cache of the connection.
        """
        return []
555
556
class TypeCode(str):
    """Class representing the type_code used by the DB-API 2.0.

    A TypeCode is a string equal to the PostgreSQL type name that also
    carries the oid, len, type, category, delim and relid of the type
    as attributes.
    """

    @classmethod
    def create(cls, oid, name, len, type, category, delim, relid):
        """Create a type code for a PostgreSQL data type."""
        code = cls(name)
        code.oid, code.len, code.type = oid, len, type
        code.category, code.delim, code.relid = category, delim, relid
        return code
575
# name and type of a field of a composite type
FieldInfo = namedtuple('FieldInfo', 'name type')
577
578
class TypeCache(dict):
    """Cache for database types.

    This cache maps type OIDs and names to TypeCode strings containing
    important information on the associated database type.  Types that
    are not cached yet are fetched from the pg_type system catalog.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()
        self._typecasts = LocalTypecasts()
        self._typecasts.get_fields = self.get_fields
        self._typecasts.connection = cnx

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The key can be a type OID (an int, or a long on Python 2, which is
        what the default typecast for oid values returns there) or a type
        name.  Names containing neither a dot nor a double quote are put in
        double quotes, so that they are taken literally.  The result is
        cached under both its OID and its name.

        Raises KeyError if the type cannot be found.
        """
        if isinstance(key, (int, long)):
            oid = key
        else:
            if '.' not in key and '"' not in key:
                key = '"%s"' % key
            oid = "'%s'::regtype" % self._escape_string(key)
        try:
            self._src.execute("SELECT oid, typname,"
                " typlen, typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s" % oid)
        except ProgrammingError:
            res = None
        else:
            res = self._src.fetch(1)
        if not res:
            raise KeyError('Type %s could not be found' % key)
        res = res[0]
        type_code = TypeCode.create(int(res[0]), res[1],
            int(res[2]), res[3], res[4], res[5], int(res[6]))
        self[type_code.oid] = self[str(type_code)] = type_code
        return type_code

    def get(self, key, default=None):
        """Get the type even if it is not cached."""
        try:
            return self[key]
        except KeyError:
            return default

    def get_fields(self, typ):
        """Get the names and types of the fields of composite types.

        Returns a list of FieldInfo tuples, or None if typ is unknown or
        is not a composite type.
        """
        if not isinstance(typ, TypeCode):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None  # this type is not composite
        self._src.execute("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
        return [FieldInfo(name, self.get(int(oid)))
            for name, oid in self._src.fetch(-1)]

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type."""
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        cast = self.get_typecast(typ)
        if not cast or cast is str:
            # no typecast is necessary
            return value
        return cast(value)
662
663
664class _quotedict(dict):
665    """Dictionary with auto quoting of its items.
666
667    The quote attribute must be set to the desired quote function.
668    """
669
670    def __getitem__(self, key):
671        return self.quote(super(_quotedict, self).__getitem__(key))
672
673
674### Error messages
675
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an error of class cls with the message.

    Errors created here carry no SQLSTATE code, so their sqlstate
    attribute is set to None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
681
682
def _op_error(msg):
    """Create (but do not raise) an OperationalError with the message."""
    return _db_error(msg, cls=OperationalError)
686
687
688### Cursor Object
689
690class Cursor(object):
691    """Cursor object."""
692
693    def __init__(self, dbcnx):
694        """Create a cursor object for the database connection."""
695        self.connection = self._dbcnx = dbcnx
696        self._cnx = dbcnx._cnx
697        self.type_cache = dbcnx.type_cache
698        self._src = self._cnx.source()
699        # the official attribute for describing the result columns
700        self._description = None
701        if self.row_factory is Cursor.row_factory:
702            # the row factory needs to be determined dynamically
703            self.row_factory = None
704        else:
705            self.build_row_factory = None
706        self.rowcount = -1
707        self.arraysize = 1
708        self.lastrowid = None
709
710    def __iter__(self):
711        """Make cursor compatible to the iteration protocol."""
712        return self
713
714    def __enter__(self):
715        """Enter the runtime context for the cursor object."""
716        return self
717
718    def __exit__(self, et, ev, tb):
719        """Exit the runtime context for the cursor object."""
720        self.close()
721
722    def _quote(self, value):
723        """Quote value depending on its type."""
724        if value is None:
725            return 'NULL'
726        if isinstance(value, (datetime, date, time, timedelta, Json)):
727            value = str(value)
728        if isinstance(value, basestring):
729            if isinstance(value, Binary):
730                value = self._cnx.escape_bytea(value)
731                if bytes is not str:  # Python >= 3.0
732                    value = value.decode('ascii')
733            else:
734                value = self._cnx.escape_string(value)
735            return "'%s'" % value
736        if isinstance(value, float):
737            if isinf(value):
738                return "'-Infinity'" if value < 0 else "'Infinity'"
739            if isnan(value):
740                return "'NaN'"
741            return value
742        if isinstance(value, (int, long, Decimal, Literal)):
743            return value
744        if isinstance(value, list):
745            # Quote value as an ARRAY constructor. This is better than using
746            # an array literal because it carries the information that this is
747            # an array and not a string.  One issue with this syntax is that
748            # you need to add an explicit typecast when passing empty arrays.
749            # The ARRAY keyword is actually only necessary at the top level.
750            q = self._quote
751            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
752        if isinstance(value, tuple):
753            # Quote as a ROW constructor.  This is better than using a record
754            # literal because it carries the information that this is a record
755            # and not a string.  We don't use the keyword ROW in order to make
756            # this usable with the IN syntax as well.  It is only necessary
757            # when the records has a single column which is not really useful.
758            q = self._quote
759            return '(%s)' % ','.join(str(q(v)) for v in value)
760        try:
761            value = value.__pg_repr__()
762        except AttributeError:
763            raise InterfaceError(
764                'Do not know how to adapt type %s' % type(value))
765        if isinstance(value, (tuple, list)):
766            value = self._quote(value)
767        return value
768
769    def _quoteparams(self, string, parameters):
770        """Quote parameters.
771
772        This function works for both mappings and sequences.
773        """
774        if isinstance(parameters, dict):
775            parameters = _quotedict(parameters)
776            parameters.quote = self._quote
777        else:
778            parameters = tuple(map(self._quote, parameters))
779        return string % parameters
780
781    def _make_description(self, info):
782        """Make the description tuple for the given field info."""
783        name, typ, size, mod = info[1:]
784        type_code = self.type_cache[typ]
785        if mod > 0:
786            mod -= 4
787        if type_code == 'numeric':
788            precision, scale = mod >> 16, mod & 0xffff
789            size = precision
790        else:
791            if not size:
792                size = type_code.size
793            if size == -1:
794                size = mod
795            precision = scale = None
796        return CursorDescription(name, type_code,
797            None, size, precision, scale, None)
798
799    @property
800    def description(self):
801        """Read-only attribute describing the result columns."""
802        descr = self._description
803        if self._description is True:
804            make = self._make_description
805            descr = [make(info) for info in self._src.listinfo()]
806            self._description = descr
807        return descr
808
809    @property
810    def colnames(self):
811        """Unofficial convenience method for getting the column names."""
812        return [d[0] for d in self.description]
813
814    @property
815    def coltypes(self):
816        """Unofficial convenience method for getting the column types."""
817        return [d[1] for d in self.description]
818
819    def close(self):
820        """Close the cursor object."""
821        self._src.close()
822        self._description = None
823        self.rowcount = -1
824        self.lastrowid = None
825
826    def execute(self, operation, parameters=None):
827        """Prepare and execute a database operation (query or command)."""
828        # The parameters may also be specified as list of tuples to e.g.
829        # insert multiple rows in a single operation, but this kind of
830        # usage is deprecated.  We make several plausibility checks because
831        # tuples can also be passed with the meaning of ROW constructors.
832        if (parameters and isinstance(parameters, list)
833                and len(parameters) > 1
834                and all(isinstance(p, tuple) for p in parameters)
835                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
836            return self.executemany(operation, parameters)
837        else:
838            # not a list of tuples
839            return self.executemany(operation, [parameters])
840
    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        The operation is run once for every item of seq_of_parameters; when
        an item is not empty, it is substituted into the operation using
        _quoteparams().  If the connection has no open transaction yet, a
        BEGIN is executed first and the connection is marked accordingly.

        Afterwards, for a RESULT_DQL result the description is fetched on
        demand, rowcount is the ntuples of the result and a new row factory
        is built if build_row_factory is set.  Otherwise rowcount is the sum
        of the values returned by the individual executions and lastrowid
        is taken from oidstatus().

        Returns the cursor object, or None if seq_of_parameters is empty.
        DatabaseError exceptions are re-raised unchanged; other errors are
        converted using _db_error() (as InterfaceError) or _op_error().
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self._description = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0  # sum of the values returned by the executions
        # sql always holds the statement currently being run, so that it
        # can be reported in the error messages below
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                # start a transaction if none is open on the connection
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception as err:  # NOTE(review): err is unused
                    raise _op_error("Can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                if parameters:
                    sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                # NOTE(review): the original remark said "true if not DML",
                # but rows is summed up as rowcount for non-DQL results
                # below, so it looks like the affected-row count - confirm
                if rows:
                    rowcount += rows
                else:
                    # NOTE(review): always overwritten below
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "Error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("Internal error in '%s': %s" % (sql, err))
        # then initialize result row count and description
        if self._src.resulttype == RESULT_DQL:
            self._description = True  # fetch on demand
            self.rowcount = self._src.ntuples
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self
889
890    def fetchone(self):
891        """Fetch the next row of a query result set."""
892        res = self.fetchmany(1, False)
893        try:
894            return res[0]
895        except IndexError:
896            return None
897
898    def fetchall(self):
899        """Fetch all (remaining) rows of a query result."""
900        return self.fetchmany(-1, False)
901
902    def fetchmany(self, size=None, keep=False):
903        """Fetch the next set of rows of a query result.
904
905        The number of rows to fetch per call is specified by the
906        size parameter. If it is not given, the cursor's arraysize
907        determines the number of rows to be fetched. If you set
908        the keep parameter to true, this is kept as new arraysize.
909        """
910        if size is None:
911            size = self.arraysize
912        if keep:
913            self.arraysize = size
914        try:
915            result = self._src.fetch(size)
916        except DatabaseError:
917            raise
918        except Error as err:
919            raise _db_error(str(err))
920        typecast = self.type_cache.typecast
921        return [self.row_factory([typecast(value, typ)
922            for typ, value in zip(self.coltypes, row)]) for row in result]
923
924    def callproc(self, procname, parameters=None):
925        """Call a stored database procedure with the given name.
926
927        The sequence of parameters must contain one entry for each input
928        argument that the procedure expects. The result of the call is the
929        same as this input sequence; replacement of output and input/output
930        parameters in the return value is currently not supported.
931
932        The procedure may also provide a result set as output. These can be
933        requested through the standard fetch methods of the cursor.
934        """
935        n = parameters and len(parameters) or 0
936        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
937        self.execute(query, parameters)
938        return parameters
939
940    def copy_from(self, stream, table,
941            format=None, sep=None, null=None, size=None, columns=None):
942        """Copy data from an input stream to the specified table.
943
944        The input stream can be a file-like object with a read() method or
945        it can also be an iterable returning a row or multiple rows of input
946        on each iteration.
947
948        The format must be text, csv or binary. The sep option sets the
949        column separator (delimiter) used in the non binary formats.
950        The null option sets the textual representation of NULL in the input.
951
952        The size option sets the size of the buffer used when reading data
953        from file-like objects.
954
955        The copy operation can be restricted to a subset of columns. If no
956        columns are specified, all of them will be copied.
957        """
958        binary_format = format == 'binary'
959        try:
960            read = stream.read
961        except AttributeError:
962            if size:
963                raise ValueError("Size must only be set for file-like objects")
964            if binary_format:
965                input_type = bytes
966                type_name = 'byte strings'
967            else:
968                input_type = basestring
969                type_name = 'strings'
970
971            if isinstance(stream, basestring):
972                if not isinstance(stream, input_type):
973                    raise ValueError("The input must be %s" % type_name)
974                if not binary_format:
975                    if isinstance(stream, str):
976                        if not stream.endswith('\n'):
977                            stream += '\n'
978                    else:
979                        if not stream.endswith(b'\n'):
980                            stream += b'\n'
981
982                def chunks():
983                    yield stream
984
985            elif isinstance(stream, Iterable):
986
987                def chunks():
988                    for chunk in stream:
989                        if not isinstance(chunk, input_type):
990                            raise ValueError(
991                                "Input stream must consist of %s" % type_name)
992                        if isinstance(chunk, str):
993                            if not chunk.endswith('\n'):
994                                chunk += '\n'
995                        else:
996                            if not chunk.endswith(b'\n'):
997                                chunk += b'\n'
998                        yield chunk
999
1000            else:
1001                raise TypeError("Need an input stream to copy from")
1002        else:
1003            if size is None:
1004                size = 8192
1005            elif not isinstance(size, int):
1006                raise TypeError("The size option must be an integer")
1007            if size > 0:
1008
1009                def chunks():
1010                    while True:
1011                        buffer = read(size)
1012                        yield buffer
1013                        if not buffer or len(buffer) < size:
1014                            break
1015
1016            else:
1017
1018                def chunks():
1019                    yield read()
1020
1021        if not table or not isinstance(table, basestring):
1022            raise TypeError("Need a table to copy to")
1023        if table.lower().startswith('select'):
1024                raise ValueError("Must specify a table, not a query")
1025        else:
1026            table = '"%s"' % (table,)
1027        operation = ['copy %s' % (table,)]
1028        options = []
1029        params = []
1030        if format is not None:
1031            if not isinstance(format, basestring):
1032                raise TypeError("The frmat option must be be a string")
1033            if format not in ('text', 'csv', 'binary'):
1034                raise ValueError("Invalid format")
1035            options.append('format %s' % (format,))
1036        if sep is not None:
1037            if not isinstance(sep, basestring):
1038                raise TypeError("The sep option must be a string")
1039            if format == 'binary':
1040                raise ValueError(
1041                    "The sep option is not allowed with binary format")
1042            if len(sep) != 1:
1043                raise ValueError(
1044                    "The sep option must be a single one-byte character")
1045            options.append('delimiter %s')
1046            params.append(sep)
1047        if null is not None:
1048            if not isinstance(null, basestring):
1049                raise TypeError("The null option must be a string")
1050            options.append('null %s')
1051            params.append(null)
1052        if columns:
1053            if not isinstance(columns, basestring):
1054                columns = ','.join('"%s"' % (col,) for col in columns)
1055            operation.append('(%s)' % (columns,))
1056        operation.append("from stdin")
1057        if options:
1058            operation.append('(%s)' % ','.join(options))
1059        operation = ' '.join(operation)
1060
1061        putdata = self._src.putdata
1062        self.execute(operation, params)
1063
1064        try:
1065            for chunk in chunks():
1066                putdata(chunk)
1067        except BaseException as error:
1068            self.rowcount = -1
1069            # the following call will re-raise the error
1070            putdata(error)
1071        else:
1072            self.rowcount = putdata(None)
1073
1074        # return the cursor object, so you can chain operations
1075        return self
1076
1077    def copy_to(self, stream, table,
1078            format=None, sep=None, null=None, decode=None, columns=None):
1079        """Copy data from the specified table to an output stream.
1080
1081        The output stream can be a file-like object with a write() method or
1082        it can also be None, in which case the method will return a generator
1083        yielding a row on each iteration.
1084
1085        Output will be returned as byte strings unless you set decode to true.
1086
1087        Note that you can also use a select query instead of the table name.
1088
1089        The format must be text, csv or binary. The sep option sets the
1090        column separator (delimiter) used in the non binary formats.
1091        The null option sets the textual representation of NULL in the output.
1092
1093        The copy operation can be restricted to a subset of columns. If no
1094        columns are specified, all of them will be copied.
1095        """
1096        binary_format = format == 'binary'
1097        if stream is not None:
1098            try:
1099                write = stream.write
1100            except AttributeError:
1101                raise TypeError("Need an output stream to copy to")
1102        if not table or not isinstance(table, basestring):
1103            raise TypeError("Need a table to copy to")
1104        if table.lower().startswith('select'):
1105            if columns:
1106                raise ValueError("Columns must be specified in the query")
1107            table = '(%s)' % (table,)
1108        else:
1109            table = '"%s"' % (table,)
1110        operation = ['copy %s' % (table,)]
1111        options = []
1112        params = []
1113        if format is not None:
1114            if not isinstance(format, basestring):
1115                raise TypeError("The format option must be a string")
1116            if format not in ('text', 'csv', 'binary'):
1117                raise ValueError("Invalid format")
1118            options.append('format %s' % (format,))
1119        if sep is not None:
1120            if not isinstance(sep, basestring):
1121                raise TypeError("The sep option must be a string")
1122            if binary_format:
1123                raise ValueError(
1124                    "The sep option is not allowed with binary format")
1125            if len(sep) != 1:
1126                raise ValueError(
1127                    "The sep option must be a single one-byte character")
1128            options.append('delimiter %s')
1129            params.append(sep)
1130        if null is not None:
1131            if not isinstance(null, basestring):
1132                raise TypeError("The null option must be a string")
1133            options.append('null %s')
1134            params.append(null)
1135        if decode is None:
1136            if format == 'binary':
1137                decode = False
1138            else:
1139                decode = str is unicode
1140        else:
1141            if not isinstance(decode, (int, bool)):
1142                raise TypeError("The decode option must be a boolean")
1143            if decode and binary_format:
1144                raise ValueError(
1145                    "The decode option is not allowed with binary format")
1146        if columns:
1147            if not isinstance(columns, basestring):
1148                columns = ','.join('"%s"' % (col,) for col in columns)
1149            operation.append('(%s)' % (columns,))
1150
1151        operation.append("to stdout")
1152        if options:
1153            operation.append('(%s)' % ','.join(options))
1154        operation = ' '.join(operation)
1155
1156        getdata = self._src.getdata
1157        self.execute(operation, params)
1158
1159        def copy():
1160            self.rowcount = 0
1161            while True:
1162                row = getdata(decode)
1163                if isinstance(row, int):
1164                    if self.rowcount != row:
1165                        self.rowcount = row
1166                    break
1167                self.rowcount += 1
1168                yield row
1169
1170        if stream is None:
1171            # no input stream, return the generator
1172            return copy()
1173
1174        # write the rows to the file-like input stream
1175        for row in copy():
1176            write(row)
1177
1178        # return the cursor object, so you can chain operations
1179        return self
1180
1181    def __next__(self):
1182        """Return the next row (support for the iteration protocol)."""
1183        res = self.fetchone()
1184        if res is None:
1185            raise StopIteration
1186        return res
1187
1188    # Note that since Python 2.6 the iterator protocol uses __next()__
1189    # instead of next(), we keep it only for backward compatibility of pgdb.
1190    next = __next__
1191
1192    @staticmethod
1193    def nextset():
1194        """Not supported."""
1195        raise NotSupportedError("The nextset() method is not supported")
1196
1197    @staticmethod
1198    def setinputsizes(sizes):
1199        """Not supported."""
1200        pass  # unsupported, but silently passed
1201
1202    @staticmethod
1203    def setoutputsize(size, column=0):
1204        """Not supported."""
1205        pass  # unsupported, but silently passed
1206
1207    @staticmethod
1208    def row_factory(row):
1209        """Process rows before they are returned.
1210
1211        You can overwrite this statically with a custom row factory, or
1212        you can build a row factory dynamically with build_row_factory().
1213
1214        For example, you can create a Cursor class that returns rows as
1215        Python dictionaries like this:
1216
1217            class DictCursor(pgdb.Cursor):
1218
1219                def row_factory(self, row):
1220                    return {desc[0]: value
1221                        for desc, value in zip(self.description, row)}
1222
1223            cur = DictCursor(con)  # get one DictCursor instance or
1224            con.cursor_type = DictCursor  # always use DictCursor instances
1225        """
1226        raise NotImplementedError
1227
1228    def build_row_factory(self):
1229        """Build a row factory based on the current description.
1230
1231        This implementation builds a row factory for creating named tuples.
1232        You can overwrite this method if you want to dynamically create
1233        different row factories whenever the column description changes.
1234        """
1235        colnames = self.colnames
1236        if colnames:
1237            try:
1238                try:
1239                    return namedtuple('Row', colnames, rename=True)._make
1240                except TypeError:  # Python 2.6 and 3.0 do not support rename
1241                    colnames = [v if v.isalnum() else 'column_%d' % n
1242                             for n, v in enumerate(colnames)]
1243                    return namedtuple('Row', colnames)._make
1244            except ValueError:  # there is still a problem with the field names
1245                colnames = ['column_%d' % n for n in range(len(colnames))]
1246                return namedtuple('Row', colnames)._make
1247
1248
# Named tuple type for the seven items describing a result column
# as required by the DB-API 2 specification.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
1252
1253
1254### Connection Objects
1255
class Connection(object):
    """Connection object.

    Wraps a low-level connection and keeps track of whether a transaction
    has been started on it, so that it can be committed or rolled back.
    The DB-API exception classes are also exposed as class attributes.
    """

    # make the exception classes reachable through the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object wrapping cnx.

        The error built by _op_error() is raised if no source object can
        be created from the given connection.
        """
        self._cnx = cnx  # the low-level connection, None after close()
        self._tnx = False  # true while a transaction has been started
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # the class instantiated by cursor()
        try:
            self._cnx.source()  # make sure the connection is usable
        except Exception:
            raise _op_error("Invalid connection")

    def __enter__(self):
        """Enter the runtime context, which can be used for transactions."""
        return self

    def __exit__(self, et, ev, tb):
        """Leave the runtime context, ending the transaction.

        The transaction is rolled back if an exception occurred and
        committed otherwise.  The connection itself is not closed.
        """
        if et is not None or ev is not None or tb is not None:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection, rolling back any started transaction."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # the connection will be closed anyway
        self._cnx.close()
        self._cnx = None

    def commit(self):
        """Commit the pending transaction, if any, to the database."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return
        self._tnx = False
        try:
            self._cnx.source().execute("COMMIT")
        except DatabaseError:
            raise
        except Exception:
            raise _op_error("Can't commit")

    def rollback(self):
        """Roll back the pending transaction, if any."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return
        self._tnx = False
        try:
            self._cnx.source().execute("ROLLBACK")
        except DatabaseError:
            raise
        except Exception:
            raise _op_error("Can't rollback")

    def cursor(self):
        """Return a new instance of cursor_type for this connection."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("Invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Run an operation on a new implicit cursor and return it."""
            cur = self.cursor()
            cur.execute(operation, params)
            return cur

        def executemany(self, operation, param_seq):
            """Run an operation against a sequence on a new implicit cursor."""
            cur = self.cursor()
            cur.executemany(operation, param_seq)
            return cur
1363
1364
1365### Module Interface
1366
# Keep the connect() function bound so far under another name, since it is
# shadowed by the DB-API connect() function defined below, which calls it
# as _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd).
# NOTE(review): presumably this is the connect() of the low-level C
# extension module imported at the top of the file - confirm.
_connect_ = connect
1368
def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database and return a Connection object.

    The optional dsn string has the form 'host:database:user:password:opt';
    missing trailing parts are taken as empty.  The keyword arguments user,
    password, database and host take precedence over the parts of the dsn,
    and the host may be given as 'host:port'.  An empty host or user is
    passed on to the low-level connect function as None.
    """
    # host, database, user, password and options taken from the DSN
    dsn_parts = ["", "", "", "", ""]
    try:
        for index, part in enumerate(dsn.split(":")[:5]):
            dsn_parts[index] = part
    except (AttributeError, IndexError, TypeError):
        pass
    dbhost, dbbase, dbuser, dbpasswd, dbopt = dsn_parts
    dbport = -1

    # explicit keyword arguments override the DSN
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass

    # empty host is localhost, empty values are passed on as None
    if dbhost == "":
        dbhost = None
    if dbuser == "":
        dbuser = None

    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
1414
1415
1416### Types Handling
1417
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozenset of type names.  It compares equal to a type name
    string if that name is one of its members, ignoring a leading underscore
    (the prefix of PostgreSQL array type names).  Comparisons with other
    objects work as for frozensets.  Type objects are hashable.
    """

    def __new__(cls, values):
        # a string is taken as a whitespace separated list of type names
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Defining __eq__ implicitly sets __hash__ to None in Python 3, which
    # made type objects unhashable there (but not in Python 2).  Restore the
    # frozenset hash so that they can be put into sets or used as dict keys.
    # Note that a Type which equals a type name string still hashes
    # differently from that string, so looking it up by name does not work.
    __hash__ = frozenset.__hash__
1445
1446
class ArrayType:
    """Type class for PostgreSQL array types.

    Compares equal to type name strings starting with an underscore,
    which is how PostgreSQL names array types, and to other ArrayType
    instances.
    """

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return isinstance(other, ArrayType)
        return other.startswith('_')

    def __ne__(self, other):
        if not isinstance(other, basestring):
            return not isinstance(other, ArrayType)
        return not other.startswith('_')
1461
1462
class RecordType:
    """Type class for PostgreSQL record types.

    Compares equal to TypeCode objects whose type attribute is 'c'
    (composite types), to the plain type name 'record', and to other
    RecordType instances.
    """

    def __eq__(self, other):
        # TypeCode is checked before plain strings
        # NOTE(review): presumably because TypeCode is a string subclass
        if isinstance(other, TypeCode):
            return other.type == 'c'
        if isinstance(other, basestring):
            return other == 'record'
        return isinstance(other, RecordType)

    def __ne__(self, other):
        if isinstance(other, TypeCode):
            return other.type != 'c'
        if isinstance(other, basestring):
            return other != 'record'
        return not isinstance(other, RecordType)
1481
1482
# Mandatory type objects defined by DB-API 2 specs:

STRING = Type(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = Type(['bytea'])
NUMBER = Type(['int2', 'int4', 'serial', 'int8',
               'float4', 'float8', 'numeric', 'money'])
DATETIME = Type(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
                 'interval',
                 'abstime', 'reltime'])  # these two are very old
ROWID = Type(['oid'])


# Additional type objects (more specific):

BOOL = Type(['bool'])
SMALLINT = Type(['int2'])
INTEGER = Type(['int2', 'int4', 'int8', 'serial'])
LONG = Type(['int8'])
FLOAT = Type(['float4', 'float8'])
NUMERIC = Type(['numeric'])
MONEY = Type(['money'])
DATE = Type(['date'])
TIME = Type(['time', 'timetz'])
TIMESTAMP = Type(['timestamp', 'timestamptz'])
INTERVAL = Type(['interval'])
JSON = Type(['json', 'jsonb'])

# Type object for arrays (also equate to their base types):

ARRAY = ArrayType()

# Type object for records (encompassing all composite types):

RECORD = RecordType()
1515
1516
1517# Mandatory type helpers defined by DB-API 2 specs:
1518
def Date(year, month, day):
    """Construct an object holding a date value (a datetime.date)."""
    return date(year=year, month=month, day=day)
1522
1523
def Time(hour, minute=0, second=0, microsecond=0):
    """Construct an object holding a time value (a datetime.time)."""
    return time(hour=hour, minute=minute,
                second=second, microsecond=microsecond)
1527
1528
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Construct an object holding a time stamp value (a datetime)."""
    return datetime(year=year, month=month, day=day,
                    hour=hour, minute=minute, second=second,
                    microsecond=microsecond)
1532
1533
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks (seconds since the epoch) are converted using local time.
    """
    tm = localtime(ticks)
    return Date(tm.tm_year, tm.tm_mon, tm.tm_mday)
1537
1538
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    The ticks (seconds since the epoch) are converted using local time;
    fractions of a second are dropped.
    """
    tm = localtime(ticks)
    return Time(tm.tm_hour, tm.tm_min, tm.tm_sec)
1542
1543
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    The ticks (seconds since the epoch) are converted using local time;
    fractions of a second are dropped.
    """
    tm = localtime(ticks)
    return Timestamp(tm.tm_year, tm.tm_mon, tm.tm_mday,
                     tm.tm_hour, tm.tm_min, tm.tm_sec)
1547
1548
class Binary(bytes):
    """Hold a binary (long) string value, as required by DB-API 2.

    This is simply a subclass of the bytes type.
    """
1551
1552
1553# Additional type helpers for PyGreSQL:
1554
def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
    """Construct an object holding a time interval value (a timedelta)."""
    return timedelta(days=days, hours=hours, minutes=minutes,
                     seconds=seconds, microseconds=microseconds)
1559
class Bytea(bytes):
    """Hold a value for a PostgreSQL bytea column.

    This is simply a subclass of the bytes type.
    """
1562
1563
class Json:
    """Construct a wrapper for holding an object serializable to JSON.

    Strings (presumably already JSON encoded) are returned unchanged as
    the text representation; any other object is serialized with the
    given encode function, which defaults to jsonencode.
    """

    def __init__(self, obj, encode=None):
        self.obj = obj
        # fall back to the default JSON encoder of the module
        self.encode = encode or jsonencode

    def __str__(self):
        if not isinstance(self.obj, basestring):
            return self.encode(self.obj)
        return self.obj

    # the same text serves as the PostgreSQL representation of the value
    __pg_repr__ = __str__
1578
1579
class Literal:
    """Construct a wrapper for holding a literal SQL string.

    The wrapped SQL text is returned unchanged by str() and __pg_repr__().
    """

    def __init__(self, sql):
        self.sql = sql

    def __str__(self):
        text = self.sql
        return text

    # the SQL text serves as the PostgreSQL representation unchanged
    __pg_repr__ = __str__
1590
1591
# If run as script, print some information:

if __name__ == '__main__':
    # show the version of the module, followed by the module docstring
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.