source: trunk/pgdb.py @ 849

Last change on this file since 849 was 849, checked in by cito, 4 years ago

Make timetz and timestamptz work properly with Python 2

Python 2 has no concrete timezone class, therefore we had so far returned only
naive datetimes in this case. This patch adds a simple concrete timezone class,
so we can return timetz and timestamptz with timezones in Python 2 as well.

  • Property svn:keywords set to Id
File size: 55.3 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 849 2016-02-09 13:14:26Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary or a sequence)
32    # if they are passed. The binding syntax is the same as the %
33    # operator, and the values are quoted according to their type.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function, division
67
68from _pg import *
69
70__version__ = version
71
72from datetime import date, time, datetime, timedelta, tzinfo
73from time import localtime
74from decimal import Decimal
75from uuid import UUID as Uuid
76from math import isnan, isinf
77from collections import namedtuple
78from functools import partial
79from re import compile as regex
80from json import loads as jsondecode, dumps as jsonencode
81
# Compatibility shims: bind the Python 2 only builtin names to their
# Python 3 counterparts so the rest of the module can use them on both.

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# The abstract base class aliases in the collections module were removed
# in Python 3.10, so import from collections.abc first and only fall back
# to the old location on Python versions that do not have that module.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
98
99
### Module Constants

# compliant with DB API 2.0
apilevel = '2.0'

# module may be shared, but not connections
threadsafety = 1

# this module uses extended Python format codes
paramstyle = 'pyformat'

# shortcut methods have been excluded from DB API 2 and
# are not recommended by the DB SIG, but they can be handy
shortcutmethods = 1
114
115
116### Internal Type Handling
117
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the list of parameter names of the given callable."""
        parameters = signature(func).parameters
        return [name for name in parameters]
129
try:
    from datetime import timezone
except ImportError:  # Python < 3.2

    class timezone(tzinfo):
        """Simple timezone implementation with a fixed UTC offset.

        Used as a stand-in when the datetime module provides no concrete
        timezone class.  The offset is a timedelta; if no name is passed,
        a name of the form 'UTC+HH:MM' or 'UTC-HH:MM' is derived from it.
        """

        def __init__(self, offset, name=None):
            self.offset = offset
            if not name:
                # total offset in whole minutes (negative west of UTC)
                minutes = self.offset.days * 1440 + self.offset.seconds // 60
                # The sign is formatted separately: for offsets between
                # -01:00 and 00:00 the hours part is zero and would
                # otherwise be rendered with a plus sign.
                sign = '-' if minutes < 0 else '+'
                hours, minutes = divmod(abs(minutes), 60)
                name = 'UTC%s%02d:%02d' % (sign, hours, minutes)
            self.name = name

        def utcoffset(self, dt):
            """Return the fixed offset, independent of the passed datetime."""
            return self.offset

        def tzname(self, dt):
            """Return the name of this timezone."""
            return self.name

        def dst(self, dt):
            """Return None since no DST information is available."""
            return None

    timezone.utc = timezone(timedelta(0), 'UTC')

    # flag telling the cast functions that the fallback class is in use
    _has_timezone = False
else:
    _has_timezone = True
163
# time zones used in Postgres timestamptz output
_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
    UCT='+0000', UTC='+0000', WET='+0000')


def _timezone_as_offset(tz):
    """Convert a timezone string to an offset of the form '+HHMM'.

    Numeric offsets are padded with minutes or stripped of their colon;
    names are looked up in the table above, with '+0000' for unknown names.
    """
    if not tz.startswith(('+', '-')):
        return _timezones.get(tz, '+0000')
    # a short numeric offset has only the hours part
    return tz + '00' if len(tz) < 5 else tz.replace(':', '')
176
177
def _get_timezone(tz):
    """Create a timezone object for the given timezone string.

    The string is first normalized to the form '+HHMM', which is also
    used as the name of the returned timezone.
    """
    offset = _timezone_as_offset(tz)
    hours, mins = int(offset[1:3]), int(offset[3:5])
    minutes = 60 * hours + mins
    if offset[0] == '-':
        minutes = -minutes
    return timezone(timedelta(minutes=minutes), offset)
184
185
def decimal_type(decimal_type=None):
    """Get or set global type to be used for decimal values.

    When called without an argument, only the current type is returned.
    Otherwise the passed type replaces the module level Decimal and is
    registered as the global typecast for 'numeric' values.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    Decimal = decimal_type
    set_typecast('numeric', decimal_type)
    return Decimal
197
198
def cast_bool(value):
    """Cast boolean value in database format to bool.

    An empty value yields None; otherwise only the first character counts.
    """
    if not value:
        return None
    first = value[0]
    return first in ('t', 'T')
203
204
def cast_money(value):
    """Cast money value in database format to Decimal.

    An opening parenthesis is taken as a minus sign; all characters other
    than digits, dots and minus signs are dropped.  An empty value yields
    None.
    """
    if not value:
        return None
    kept = [c for c in value.replace('(', '-') if c.isdigit() or c in '.-']
    return Decimal(''.join(kept))
210
211
def cast_int2vector(value):
    """Cast an int2vector value to a list of ints."""
    items = value.split()
    return list(map(int, items))
215
216
def cast_date(value, connection):
    """Cast a date value.

    Infinite, BC and out of range dates are mapped to date.min or date.max.
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    special = {'-infinity': date.min, 'infinity': date.max}
    if value in special:
        return special[value]
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    day = parts[0]
    if len(day) > 10:  # more digits in the year than date can hold
        return date.max
    parsed = datetime.strptime(day, connection.date_format())
    return parsed.date()
235
236
def cast_time(value):
    """Cast a time value, with or without fractional seconds."""
    has_fraction = len(value) > 8
    fmt = '%H:%M:%S.%f' if has_fraction else '%H:%M:%S'
    parsed = datetime.strptime(value, fmt)
    return parsed.time()
241
242
# splits a value into the part before the last sign and the signed rest
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value.

    A value without a timezone offset is treated as UTC.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    has_fraction = len(value) > 8
    fmt = '%H:%M:%S.%f' if has_fraction else '%H:%M:%S'
    if not _has_timezone:
        # strptime cannot create the timezone here, so attach it afterwards
        naive = datetime.strptime(value, fmt).timetz()
        return naive.replace(tzinfo=_get_timezone(tz))
    value += _timezone_as_offset(tz)
    fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
260
261
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    Infinite, BC and out of range timestamps are mapped to datetime.min
    or datetime.max.
    """
    special = {'-infinity': datetime.min, 'infinity': datetime.max}
    if value in special:
        return special[value]
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # weekday, month and day (in either order), time, year
        parts = parts[1:5]
        if len(parts[3]) > 4:  # year has too many digits
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_fmt, time_fmt, '%Y']
    else:
        if len(parts[0]) > 10:  # year has too many digits
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(formats))
283
284
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    Infinite, BC and out of range timestamps are mapped to datetime.min
    or datetime.max (without timezone information).
    """
    special = {'-infinity': datetime.min, 'infinity': datetime.max}
    if value in special:
        return special[value]
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # weekday, month and day (in either order), time, year, timezone
        parts = parts[1:]
        if len(parts[3]) > 4:  # year has too many digits
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_fmt, time_fmt, '%Y']
        parts, tz = parts[:-1], parts[-1]
    else:
        if date_fmt.startswith('%Y-'):
            # the timezone is directly attached to the time as an offset
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        else:
            # the timezone is a separate item at the end
            parts, tz = parts[:-1], parts[-1]
        if len(parts[0]) > 10:  # year has too many digits
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    if not _has_timezone:
        # strptime cannot create the timezone here, so attach it afterwards
        naive = datetime.strptime(' '.join(parts), ' '.join(formats))
        return naive.replace(tzinfo=_get_timezone(tz))
    parts.append(_timezone_as_offset(tz))
    formats.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(formats))
320
321
# Regular expressions for the interval output formats that are tried
# by cast_interval().  In each of them, the last group of the seconds
# part captures the digits after the decimal point.

_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_numbers(groups):
    """Convert matched interval groups to integers, with zero if missing.

    The last group holds the digits after the decimal point of the seconds.
    It is scaled to microseconds, so that '5' (from '.5') gives 500000.
    """
    numbers = [int(d) if d else 0 for d in groups[:-1]]
    fraction = groups[-1]
    # pad or cut the fraction to exactly six digits before converting
    numbers.append(int(fraction.ljust(6, '0')[:6]) if fraction else 0)
    return numbers


def cast_interval(value):
    """Cast an interval value to a timedelta.

    Years are counted as 365 days and months as 30 days.

    Raises a ValueError if the value matches none of the known formats.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        groups = list(m.groups())
        secs_ago = groups.pop(5) == '-'
        years, mons, days, hours, mins, secs, usecs = (
            _interval_numbers(groups))
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            groups, ago = list(m.groups()[:8]), m.group(9)
            secs_ago = groups.pop(5) == '-'
            numbers = _interval_numbers(groups)
            if ago:
                # a trailing "ago" negates all of the fields
                numbers = [-n for n in numbers]
            years, mons, days, hours, mins, secs, usecs = numbers
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                groups = list(m.groups())
                # the sign before the time applies to the whole time part
                hours_ago = groups.pop(3) == '-'
                years, mons, days, hours, mins, secs, usecs = (
                    _interval_numbers(groups))
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    groups = list(m.groups())
                    # one sign for the year-month part, one for the time
                    years_ago = groups.pop(0) == '-'
                    hours_ago = groups.pop(3) == '-'
                    years, mons, days, hours, mins, secs, usecs = (
                        _interval_numbers(groups))
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
407
408
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
        'float4': float, 'float8': float,
        'numeric': Decimal, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': Uuid,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in local connection specific instances

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # create array cast
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                # store only if base type exists
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument.

        This is the case if any parameter after the first one is named
        'connection'.  Callables that cannot be inspected do not need one.
        """
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type."""
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        Passing None as cast removes the typecast.  A cached cast for the
        corresponding array type is removed in any case.

        Raises a TypeError if cast is neither None nor callable.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.
        """
        defaults = self.defaults
        if typ is None:
            self.clear()
            if not self.connection:
                self.update(defaults)
            # With a connection, the defaults must not be copied as they
            # are, since casts that need the connection would then be
            # stored without it.  The entries are left out instead and
            # created again with the connection by __missing__ on access.
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                cast = defaults.get(t)
                if cast:
                    self[t] = self._add_connection(cast)
                else:
                    self.pop(t, None)
                # Like in set(), drop the cast for the array type, so that
                # it is created again from the current base cast on access.
                self.pop('_%s' % t, None)

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
530
531
# The module level functions get_typecast(), set_typecast() and
# reset_typecast() operate on this instance; it also provides the
# defaults for the LocalTypecasts class.
_typecasts = Typecasts()  # this is the global typecast dictionary
533
534
def get_typecast(typ):
    """Get the global typecast function for the given database type.

    Returns None if there is no special typecast function for the type.
    """
    return _typecasts.get(typ)
538
539
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    The type can be a single type name or a list of type names.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    _typecasts.set(typ, cast)
547
548
def reset_typecast(typ=None):
    """Reset the global typecasts for the given type(s) to their default.

    The type can be a single type name or a list of type names.
    When no type is specified, all typecasts will be reset.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    _typecasts.reset(typ)
558
559
class LocalTypecasts(Typecasts):
    """Map typecasts, including local composite types, to cast functions."""

    defaults = _typecasts

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Array types are derived from their base type, other types are looked
        up in the defaults, and types not found there are checked for being
        composite types for which a record cast can be created.
        """
        if typ.startswith('_'):
            element_cast = self[typ[1:]]
            cast = self.create_array_cast(element_cast)
            if element_cast:
                # store only if the base type has a cast
                self[typ] = cast
            return cast
        cast = self.defaults.get(typ)
        if cast:
            cast = self._add_connection(cast)
            self[typ] = cast
            return cast
        fields = self.get_fields(typ)
        if fields:
            casts = [self[field.type] for field in fields]
            names = [field.name for field in fields]
            cast = self.create_record_cast(typ, names, casts)
            self[typ] = cast
        return cast

    def get_fields(self, typ):
        """Return the fields for the given record type.

        This method will be replaced with a method that looks up the fields
        using the type cache of the connection.
        """
        return []
595
596
class TypeCode(str):
    """Class representing the type_code used by the DB-API 2.0.

    TypeCode objects are strings equal to the PostgreSQL type name,
    but carry some additional information.
    """

    @classmethod
    def create(cls, oid, name, len, type, category, delim, relid):
        """Create a type code for a PostgreSQL data type.

        The name becomes the string value; all other arguments are
        stored as attributes with the same names.
        """
        code = cls(name)
        code.oid, code.len, code.type = oid, len, type
        code.category, code.delim, code.relid = category, delim, relid
        return code
615
# name and type of a single field of a composite type
FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
617
618
class TypeCache(dict):
    """Cache for database types.

    This cache maps type OIDs and names to TypeCode strings containing
    important information on the associated database type.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()
        # connection specific typecasts that can look up composite types
        typecasts = LocalTypecasts()
        typecasts.get_fields = self.get_fields
        typecasts.connection = cnx
        self._typecasts = typecasts

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The key can be a type OID or a type name.  The type code found is
        cached under both its OID and its name.

        Raises a KeyError if the type could not be found.
        """
        if isinstance(key, int):
            oid = key
        else:
            if not ('.' in key or '"' in key):
                key = '"%s"' % key
            oid = "'%s'::regtype" % self._escape_string(key)
        query = ("SELECT oid, typname,"
                 " typlen, typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s" % oid)
        try:
            self._src.execute(query)
        except ProgrammingError:
            rows = None
        else:
            rows = self._src.fetch(1)
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        row = rows[0]
        type_code = TypeCode.create(int(row[0]), row[1],
            int(row[2]), row[3], row[4], row[5], int(row[6]))
        self[type_code.oid] = self[str(type_code)] = type_code
        return type_code

    def get(self, key, default=None):
        """Get the type even if it is not cached."""
        try:
            found = self[key]
        except KeyError:
            return default
        return found

    def get_fields(self, typ):
        """Get the names and types of the fields of composite types.

        Returns None if the type is unknown or not composite.
        """
        if not isinstance(typ, TypeCode):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None  # this type is not composite
        self._src.execute("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
        fields = []
        for name, oid in self._src.fetch(-1):
            fields.append(FieldInfo(name, self.get(int(oid))))
        return fields

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type."""
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        cast = self.get_typecast(typ)
        if cast and cast is not str:
            return cast(value)
        # no typecast is necessary
        return value
702
703
class _quotedict(dict):
    """Dictionary with auto quoting of its items.

    The quote attribute must be set to the desired quote function.
    """

    def __getitem__(self, key):
        """Return the stored item after passing it to the quote function."""
        raw = dict.__getitem__(self, key)
        return self.quote(raw)
712
713
714### Error messages
715
def _db_error(msg, cls=DatabaseError):
    """Return an instance of cls with the sqlstate attribute set to None."""
    exc = cls(msg)
    exc.sqlstate = None
    return exc
721
722
def _op_error(msg):
    """Return an OperationalError with the sqlstate attribute set to None."""
    return _db_error(msg, cls=OperationalError)
726
727
728### Cursor Object
729
730class Cursor(object):
731    """Cursor object."""
732
733    def __init__(self, dbcnx):
734        """Create a cursor object for the database connection."""
735        self.connection = self._dbcnx = dbcnx
736        self._cnx = dbcnx._cnx
737        self.type_cache = dbcnx.type_cache
738        self._src = self._cnx.source()
739        # the official attribute for describing the result columns
740        self._description = None
741        if self.row_factory is Cursor.row_factory:
742            # the row factory needs to be determined dynamically
743            self.row_factory = None
744        else:
745            self.build_row_factory = None
746        self.rowcount = -1
747        self.arraysize = 1
748        self.lastrowid = None
749
    def __iter__(self):
        """Make cursor compatible to the iteration protocol.

        The cursor is returned as its own iterator.
        """
        return self
753
    def __enter__(self):
        """Enter the runtime context for the cursor object.

        The cursor itself is bound to the target of the with statement.
        """
        return self
757
    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object.

        The cursor is closed; the exception info passed in is ignored and
        an exception raised in the context is not suppressed.
        """
        self.close()
761
    def _quote(self, value):
        """Quote value depending on its type.

        Returns the value in a form that can be inserted into an SQL command
        with the % operator.  Numbers are returned unchanged, most other
        types are returned as quoted strings, partly with an explicit cast.
        Objects of other types are adapted using their __pg_repr__() method.

        Raises an InterfaceError if the value has an unsupported type
        and no __pg_repr__() method.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, (Hstore, Json)):
            # these are quoted using their string representation
            value = str(value)
        if isinstance(value, basestring):
            if isinstance(value, Binary):
                value = self._cnx.escape_bytea(value)
                if bytes is not str:  # Python >= 3.0
                    value = value.decode('ascii')
            else:
                value = self._cnx.escape_string(value)
            return "'%s'" % value
        if isinstance(value, float):
            # special float values must be passed as quoted strings
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal, Literal)):
            return value
        # datetime must be checked before date since it is a subclass of date
        if isinstance(value, datetime):
            if value.tzinfo:
                return "'%s'::timestamptz" % value
            return "'%s'::timestamp" % value
        if isinstance(value, date):
            return "'%s'::date" % value
        if isinstance(value, time):
            if value.tzinfo:
                return "'%s'::timetz" % value
            return "'%s'::time" % value
        if isinstance(value, timedelta):
            return "'%s'::interval" % value
        if isinstance(value, Uuid):
            return "'%s'::uuid" % value
        if isinstance(value, list):
            # Quote value as an ARRAY constructor. This is better than using
            # an array literal because it carries the information that this is
            # an array and not a string.  One issue with this syntax is that
            # you need to add an explicit typecast when passing empty arrays.
            # The ARRAY keyword is actually only necessary at the top level.
            if not value:  # exception for empty array
                return "'{}'"
            q = self._quote
            try:
                return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
            except UnicodeEncodeError:  # Python 2 with non-ascii values
                return u'ARRAY[%s]' % ','.join(unicode(q(v)) for v in value)
        if isinstance(value, tuple):
            # Quote as a ROW constructor.  This is better than using a record
            # literal because it carries the information that this is a record
            # and not a string.  We don't use the keyword ROW in order to make
            # this usable with the IN syntax as well.  It is only necessary
            # when the record has a single column which is not really useful.
            q = self._quote
            try:
                return '(%s)' % ','.join(str(q(v)) for v in value)
            except UnicodeEncodeError:  # Python 2 with non-ascii values
                return u'(%s)' % ','.join(unicode(q(v)) for v in value)
        try:
            value = value.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        if isinstance(value, (tuple, list)):
            # a sequence returned by the adapter still needs to be quoted
            value = self._quote(value)
        return value
830
831    def _quoteparams(self, string, parameters):
832        """Quote parameters.
833
834        This function works for both mappings and sequences.
835
836        The function should be used even when there are no parameters,
837        so that we have a consistent behavior regarding percent signs.
838        """
839        if not parameters:
840            try:
841                return string % ()  # unescape literal quotes if possible
842            except (TypeError, ValueError):
843                return string  # silently accept unescaped quotes
844        if isinstance(parameters, dict):
845            parameters = _quotedict(parameters)
846            parameters.quote = self._quote
847        else:
848            parameters = tuple(map(self._quote, parameters))
849        return string % parameters
850
851    def _make_description(self, info):
852        """Make the description tuple for the given field info."""
853        name, typ, size, mod = info[1:]
854        type_code = self.type_cache[typ]
855        if mod > 0:
856            mod -= 4
857        if type_code == 'numeric':
858            precision, scale = mod >> 16, mod & 0xffff
859            size = precision
860        else:
861            if not size:
862                size = type_code.size
863            if size == -1:
864                size = mod
865            precision = scale = None
866        return CursorDescription(name, type_code,
867            None, size, precision, scale, None)
868
869    @property
870    def description(self):
871        """Read-only attribute describing the result columns."""
872        descr = self._description
873        if self._description is True:
874            make = self._make_description
875            descr = [make(info) for info in self._src.listinfo()]
876            self._description = descr
877        return descr
878
879    @property
880    def colnames(self):
881        """Unofficial convenience method for getting the column names."""
882        return [d[0] for d in self.description]
883
884    @property
885    def coltypes(self):
886        """Unofficial convenience method for getting the column types."""
887        return [d[1] for d in self.description]
888
889    def close(self):
890        """Close the cursor object."""
891        self._src.close()
892
893    def execute(self, operation, parameters=None):
894        """Prepare and execute a database operation (query or command)."""
895        # The parameters may also be specified as list of tuples to e.g.
896        # insert multiple rows in a single operation, but this kind of
897        # usage is deprecated.  We make several plausibility checks because
898        # tuples can also be passed with the meaning of ROW constructors.
899        if (parameters and isinstance(parameters, list)
900                and len(parameters) > 1
901                and all(isinstance(p, tuple) for p in parameters)
902                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
903            return self.executemany(operation, parameters)
904        else:
905            # not a list of tuples
906            return self.executemany(operation, [parameters])
907
908    def executemany(self, operation, seq_of_parameters):
909        """Prepare operation and execute it against a parameter sequence."""
910        if not seq_of_parameters:
911            # don't do anything without parameters
912            return
913        self._description = None
914        self.rowcount = -1
915        # first try to execute all queries
916        rowcount = 0
917        sql = "BEGIN"
918        try:
919            if not self._dbcnx._tnx:
920                try:
921                    self._cnx.source().execute(sql)
922                except DatabaseError:
923                    raise  # database provides error message
924                except Exception:
925                    raise _op_error("Can't start transaction")
926                self._dbcnx._tnx = True
927            for parameters in seq_of_parameters:
928                sql = operation
929                sql = self._quoteparams(sql, parameters)
930                rows = self._src.execute(sql)
931                if rows:  # true if not DML
932                    rowcount += rows
933                else:
934                    self.rowcount = -1
935        except DatabaseError:
936            raise  # database provides error message
937        except Error as err:
938            raise _db_error(
939                "Error in '%s': '%s' " % (sql, err), InterfaceError)
940        except Exception as err:
941            raise _op_error("Internal error in '%s': %s" % (sql, err))
942        # then initialize result raw count and description
943        if self._src.resulttype == RESULT_DQL:
944            self._description = True  # fetch on demand
945            self.rowcount = self._src.ntuples
946            self.lastrowid = None
947            if self.build_row_factory:
948                self.row_factory = self.build_row_factory()
949        else:
950            self.rowcount = rowcount
951            self.lastrowid = self._src.oidstatus()
952        # return the cursor object, so you can write statements such as
953        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
954        return self
955
956    def fetchone(self):
957        """Fetch the next row of a query result set."""
958        res = self.fetchmany(1, False)
959        try:
960            return res[0]
961        except IndexError:
962            return None
963
964    def fetchall(self):
965        """Fetch all (remaining) rows of a query result."""
966        return self.fetchmany(-1, False)
967
968    def fetchmany(self, size=None, keep=False):
969        """Fetch the next set of rows of a query result.
970
971        The number of rows to fetch per call is specified by the
972        size parameter. If it is not given, the cursor's arraysize
973        determines the number of rows to be fetched. If you set
974        the keep parameter to true, this is kept as new arraysize.
975        """
976        if size is None:
977            size = self.arraysize
978        if keep:
979            self.arraysize = size
980        try:
981            result = self._src.fetch(size)
982        except DatabaseError:
983            raise
984        except Error as err:
985            raise _db_error(str(err))
986        typecast = self.type_cache.typecast
987        return [self.row_factory([typecast(value, typ)
988            for typ, value in zip(self.coltypes, row)]) for row in result]
989
990    def callproc(self, procname, parameters=None):
991        """Call a stored database procedure with the given name.
992
993        The sequence of parameters must contain one entry for each input
994        argument that the procedure expects. The result of the call is the
995        same as this input sequence; replacement of output and input/output
996        parameters in the return value is currently not supported.
997
998        The procedure may also provide a result set as output. These can be
999        requested through the standard fetch methods of the cursor.
1000        """
1001        n = parameters and len(parameters) or 0
1002        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
1003        self.execute(query, parameters)
1004        return parameters
1005
1006    def copy_from(self, stream, table,
1007            format=None, sep=None, null=None, size=None, columns=None):
1008        """Copy data from an input stream to the specified table.
1009
1010        The input stream can be a file-like object with a read() method or
1011        it can also be an iterable returning a row or multiple rows of input
1012        on each iteration.
1013
1014        The format must be text, csv or binary. The sep option sets the
1015        column separator (delimiter) used in the non binary formats.
1016        The null option sets the textual representation of NULL in the input.
1017
1018        The size option sets the size of the buffer used when reading data
1019        from file-like objects.
1020
1021        The copy operation can be restricted to a subset of columns. If no
1022        columns are specified, all of them will be copied.
1023        """
1024        binary_format = format == 'binary'
1025        try:
1026            read = stream.read
1027        except AttributeError:
1028            if size:
1029                raise ValueError("Size must only be set for file-like objects")
1030            if binary_format:
1031                input_type = bytes
1032                type_name = 'byte strings'
1033            else:
1034                input_type = basestring
1035                type_name = 'strings'
1036
1037            if isinstance(stream, basestring):
1038                if not isinstance(stream, input_type):
1039                    raise ValueError("The input must be %s" % type_name)
1040                if not binary_format:
1041                    if isinstance(stream, str):
1042                        if not stream.endswith('\n'):
1043                            stream += '\n'
1044                    else:
1045                        if not stream.endswith(b'\n'):
1046                            stream += b'\n'
1047
1048                def chunks():
1049                    yield stream
1050
1051            elif isinstance(stream, Iterable):
1052
1053                def chunks():
1054                    for chunk in stream:
1055                        if not isinstance(chunk, input_type):
1056                            raise ValueError(
1057                                "Input stream must consist of %s" % type_name)
1058                        if isinstance(chunk, str):
1059                            if not chunk.endswith('\n'):
1060                                chunk += '\n'
1061                        else:
1062                            if not chunk.endswith(b'\n'):
1063                                chunk += b'\n'
1064                        yield chunk
1065
1066            else:
1067                raise TypeError("Need an input stream to copy from")
1068        else:
1069            if size is None:
1070                size = 8192
1071            elif not isinstance(size, int):
1072                raise TypeError("The size option must be an integer")
1073            if size > 0:
1074
1075                def chunks():
1076                    while True:
1077                        buffer = read(size)
1078                        yield buffer
1079                        if not buffer or len(buffer) < size:
1080                            break
1081
1082            else:
1083
1084                def chunks():
1085                    yield read()
1086
1087        if not table or not isinstance(table, basestring):
1088            raise TypeError("Need a table to copy to")
1089        if table.lower().startswith('select'):
1090                raise ValueError("Must specify a table, not a query")
1091        else:
1092            table = '"%s"' % (table,)
1093        operation = ['copy %s' % (table,)]
1094        options = []
1095        params = []
1096        if format is not None:
1097            if not isinstance(format, basestring):
1098                raise TypeError("The frmat option must be be a string")
1099            if format not in ('text', 'csv', 'binary'):
1100                raise ValueError("Invalid format")
1101            options.append('format %s' % (format,))
1102        if sep is not None:
1103            if not isinstance(sep, basestring):
1104                raise TypeError("The sep option must be a string")
1105            if format == 'binary':
1106                raise ValueError(
1107                    "The sep option is not allowed with binary format")
1108            if len(sep) != 1:
1109                raise ValueError(
1110                    "The sep option must be a single one-byte character")
1111            options.append('delimiter %s')
1112            params.append(sep)
1113        if null is not None:
1114            if not isinstance(null, basestring):
1115                raise TypeError("The null option must be a string")
1116            options.append('null %s')
1117            params.append(null)
1118        if columns:
1119            if not isinstance(columns, basestring):
1120                columns = ','.join('"%s"' % (col,) for col in columns)
1121            operation.append('(%s)' % (columns,))
1122        operation.append("from stdin")
1123        if options:
1124            operation.append('(%s)' % ','.join(options))
1125        operation = ' '.join(operation)
1126
1127        putdata = self._src.putdata
1128        self.execute(operation, params)
1129
1130        try:
1131            for chunk in chunks():
1132                putdata(chunk)
1133        except BaseException as error:
1134            self.rowcount = -1
1135            # the following call will re-raise the error
1136            putdata(error)
1137        else:
1138            self.rowcount = putdata(None)
1139
1140        # return the cursor object, so you can chain operations
1141        return self
1142
1143    def copy_to(self, stream, table,
1144            format=None, sep=None, null=None, decode=None, columns=None):
1145        """Copy data from the specified table to an output stream.
1146
1147        The output stream can be a file-like object with a write() method or
1148        it can also be None, in which case the method will return a generator
1149        yielding a row on each iteration.
1150
1151        Output will be returned as byte strings unless you set decode to true.
1152
1153        Note that you can also use a select query instead of the table name.
1154
1155        The format must be text, csv or binary. The sep option sets the
1156        column separator (delimiter) used in the non binary formats.
1157        The null option sets the textual representation of NULL in the output.
1158
1159        The copy operation can be restricted to a subset of columns. If no
1160        columns are specified, all of them will be copied.
1161        """
1162        binary_format = format == 'binary'
1163        if stream is not None:
1164            try:
1165                write = stream.write
1166            except AttributeError:
1167                raise TypeError("Need an output stream to copy to")
1168        if not table or not isinstance(table, basestring):
1169            raise TypeError("Need a table to copy to")
1170        if table.lower().startswith('select'):
1171            if columns:
1172                raise ValueError("Columns must be specified in the query")
1173            table = '(%s)' % (table,)
1174        else:
1175            table = '"%s"' % (table,)
1176        operation = ['copy %s' % (table,)]
1177        options = []
1178        params = []
1179        if format is not None:
1180            if not isinstance(format, basestring):
1181                raise TypeError("The format option must be a string")
1182            if format not in ('text', 'csv', 'binary'):
1183                raise ValueError("Invalid format")
1184            options.append('format %s' % (format,))
1185        if sep is not None:
1186            if not isinstance(sep, basestring):
1187                raise TypeError("The sep option must be a string")
1188            if binary_format:
1189                raise ValueError(
1190                    "The sep option is not allowed with binary format")
1191            if len(sep) != 1:
1192                raise ValueError(
1193                    "The sep option must be a single one-byte character")
1194            options.append('delimiter %s')
1195            params.append(sep)
1196        if null is not None:
1197            if not isinstance(null, basestring):
1198                raise TypeError("The null option must be a string")
1199            options.append('null %s')
1200            params.append(null)
1201        if decode is None:
1202            if format == 'binary':
1203                decode = False
1204            else:
1205                decode = str is unicode
1206        else:
1207            if not isinstance(decode, (int, bool)):
1208                raise TypeError("The decode option must be a boolean")
1209            if decode and binary_format:
1210                raise ValueError(
1211                    "The decode option is not allowed with binary format")
1212        if columns:
1213            if not isinstance(columns, basestring):
1214                columns = ','.join('"%s"' % (col,) for col in columns)
1215            operation.append('(%s)' % (columns,))
1216
1217        operation.append("to stdout")
1218        if options:
1219            operation.append('(%s)' % ','.join(options))
1220        operation = ' '.join(operation)
1221
1222        getdata = self._src.getdata
1223        self.execute(operation, params)
1224
1225        def copy():
1226            self.rowcount = 0
1227            while True:
1228                row = getdata(decode)
1229                if isinstance(row, int):
1230                    if self.rowcount != row:
1231                        self.rowcount = row
1232                    break
1233                self.rowcount += 1
1234                yield row
1235
1236        if stream is None:
1237            # no input stream, return the generator
1238            return copy()
1239
1240        # write the rows to the file-like input stream
1241        for row in copy():
1242            write(row)
1243
1244        # return the cursor object, so you can chain operations
1245        return self
1246
1247    def __next__(self):
1248        """Return the next row (support for the iteration protocol)."""
1249        res = self.fetchone()
1250        if res is None:
1251            raise StopIteration
1252        return res
1253
1254    # Note that since Python 2.6 the iterator protocol uses __next()__
1255    # instead of next(), we keep it only for backward compatibility of pgdb.
1256    next = __next__
1257
1258    @staticmethod
1259    def nextset():
1260        """Not supported."""
1261        raise NotSupportedError("The nextset() method is not supported")
1262
1263    @staticmethod
1264    def setinputsizes(sizes):
1265        """Not supported."""
1266        pass  # unsupported, but silently passed
1267
1268    @staticmethod
1269    def setoutputsize(size, column=0):
1270        """Not supported."""
1271        pass  # unsupported, but silently passed
1272
1273    @staticmethod
1274    def row_factory(row):
1275        """Process rows before they are returned.
1276
1277        You can overwrite this statically with a custom row factory, or
1278        you can build a row factory dynamically with build_row_factory().
1279
1280        For example, you can create a Cursor class that returns rows as
1281        Python dictionaries like this:
1282
1283            class DictCursor(pgdb.Cursor):
1284
1285                def row_factory(self, row):
1286                    return {desc[0]: value
1287                        for desc, value in zip(self.description, row)}
1288
1289            cur = DictCursor(con)  # get one DictCursor instance or
1290            con.cursor_type = DictCursor  # always use DictCursor instances
1291        """
1292        raise NotImplementedError
1293
1294    def build_row_factory(self):
1295        """Build a row factory based on the current description.
1296
1297        This implementation builds a row factory for creating named tuples.
1298        You can overwrite this method if you want to dynamically create
1299        different row factories whenever the column description changes.
1300        """
1301        colnames = self.colnames
1302        if colnames:
1303            try:
1304                try:
1305                    return namedtuple('Row', colnames, rename=True)._make
1306                except TypeError:  # Python 2.6 and 3.0 do not support rename
1307                    colnames = [v if v.isalnum() else 'column_%d' % n
1308                             for n, v in enumerate(colnames)]
1309                    return namedtuple('Row', colnames)._make
1310            except ValueError:  # there is still a problem with the field names
1311                colnames = ['column_%d' % n for n in range(len(colnames))]
1312                return namedtuple('Row', colnames)._make
1313
1314
# Named tuple used for the entries of Cursor.description; the seven
# fields are those required for description entries by DB-API 2.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
1318
1319
1320### Connection Objects
1321
class Connection(object):
    """Connection object.

    Wraps a low-level connection and keeps track of whether a
    transaction has been started through one of its cursors.
    """

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        Raises an operational error if no source object can be
        obtained from the passed low-level connection.
        """
        self._cnx = cnx  # the low-level connection
        self._tnx = False  # whether a transaction has been started
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("Invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is rolled back if an exception was passed, else committed.
        """
        if et is not None or ev is not None or tb is not None:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back before closing.
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # close the connection anyway
        self._cnx.close()
        self._cnx = None

    @property
    def closed(self):
        """Check whether the connection has been closed or is broken."""
        try:
            if not self._cnx:
                return True
            return self._cnx.status != 1
        except TypeError:
            return True

    def commit(self):
        """Commit any pending transaction to the database."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            # the transaction counts as ended even if the commit fails
            self._tnx = False
            try:
                self._cnx.source().execute("COMMIT")
            except DatabaseError:
                raise
            except Exception:
                raise _op_error("Can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            # the transaction counts as ended even if the rollback fails
            self._tnx = False
            try:
                self._cnx.source().execute("ROLLBACK")
            except DatabaseError:
                raise
            except Exception:
                raise _op_error("Can't rollback")

    def cursor(self):
        """Return a new cursor object using the connection.

        The cursor is an instance of the cursor_type of the connection.
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("Invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            implicit_cursor = self.cursor()
            implicit_cursor.execute(operation, params)
            return implicit_cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            implicit_cursor = self.cursor()
            implicit_cursor.executemany(operation, param_seq)
            return implicit_cursor
1437
1438
1439### Module Interface
1440
# keep the connect function that was in scope under a private name,
# since the name is redefined by the DB-API 2 function below
_connect = connect

def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database.

    The dsn must be a string of the form 'host:database:user:password:opt'
    where trailing parts may be omitted.  The keyword arguments override
    the corresponding parts of the dsn.  The host may carry a port number
    in the form 'host:port'.

    Returns a Connection object wrapping the opened connection.
    """
    # first split the DSN, missing parts become empty strings
    try:
        parts = list(dsn.split(":"))
    except (AttributeError, TypeError):
        parts = []
    dbhost, dbbase, dbuser, dbpasswd, dbopt = (parts + [""] * 5)[:5]
    dbport = -1

    # then let the keyword arguments override the DSN
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            # a missing or non-numeric port is silently ignored
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass

    # an empty host or user is passed on as None
    if dbhost == "":
        dbhost = None
    if dbuser == "":
        dbuser = None

    # open the connection and wrap it
    return Connection(
        _connect(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd))
1488
1489
1490### Types Handling
1491
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozen set of type names.  It compares equal to every
    string that is one of these names, also when the string is prefixed
    with one underscore.
    """

    def __new__(cls, values):
        """Create the type from an iterable or a blank separated string."""
        names = values.split() if isinstance(values, basestring) else values
        return super(Type, cls).__new__(cls, names)

    def __eq__(self, other):
        """Check whether the given type name belongs to this type."""
        if not isinstance(other, basestring):
            return super(Type, self).__eq__(other)
        # ignore one leading underscore
        name = other[1:] if other.startswith('_') else other
        return name in self

    def __ne__(self, other):
        """Check whether the given type name does not belong to this type."""
        if not isinstance(other, basestring):
            return super(Type, self).__ne__(other)
        # ignore one leading underscore
        name = other[1:] if other.startswith('_') else other
        return name not in self
1519
1520
class ArrayType:
    """Type class for PostgreSQL array types.

    Instances compare equal to every string that starts with an
    underscore and to all other ArrayType instances.
    """

    def __eq__(self, other):
        if isinstance(other, basestring):
            is_array = other.startswith('_')
        else:
            is_array = isinstance(other, ArrayType)
        return is_array

    def __ne__(self, other):
        if isinstance(other, basestring):
            is_array = other.startswith('_')
        else:
            is_array = isinstance(other, ArrayType)
        return not is_array
1535
1536
class RecordType:
    """Type class for PostgreSQL record types.

    Instances compare equal to every TypeCode with a type attribute
    of 'c', to the string 'record' and to all other RecordType instances.
    """

    def __eq__(self, other):
        if isinstance(other, TypeCode):
            return other.type == 'c'
        if isinstance(other, basestring):
            return other == 'record'
        return isinstance(other, RecordType)

    def __ne__(self, other):
        if isinstance(other, TypeCode):
            return other.type != 'c'
        if isinstance(other, basestring):
            return other != 'record'
        return not isinstance(other, RecordType)
1555
1556
1557# Mandatory type objects defined by DB-API 2 specs:
1558
STRING = Type('char bpchar name text varchar')  # character data
BINARY = Type('bytea')  # binary data
NUMBER = Type(
    'int2 int4 serial int8 float4 float8 numeric money')  # all numbers
DATETIME = Type(
    'date time timetz timestamp timestamptz interval'
    ' abstime reltime')  # the last two are very old
ROWID = Type('oid')  # object identifiers


# Additional type objects (more specific):

BOOL = Type('bool')
SMALLINT = Type('int2')
INTEGER = Type('int2 int4 int8 serial')
LONG = Type('int8')
FLOAT = Type('float4 float8')
NUMERIC = Type('numeric')
MONEY = Type('money')
DATE = Type('date')
TIME = Type('time timetz')
TIMESTAMP = Type('timestamp timestamptz')
INTERVAL = Type('interval')
UUID = Type('uuid')
HSTORE = Type('hstore')
JSON = Type('json jsonb')

# Type object for arrays (also equate to their base types):

ARRAY = ArrayType()

# Type object for records (encompassing all composite types):

RECORD = RecordType()
1591
1592
1593# Mandatory type helpers defined by DB-API 2 specs:
1594
def Date(year, month, day):
    """Construct an object holding a date value.

    Returns a datetime.date instance for the given year, month and day.
    """
    return date(year=year, month=month, day=day)
1598
1599
def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Construct an object holding a time value.

    Returns a datetime.time instance that is naive unless a tzinfo
    object has been passed.
    """
    return time(hour=hour, minute=minute, second=second,
        microsecond=microsecond, tzinfo=tzinfo)
1603
1604
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Construct an object holding a time stamp value.

    Returns a datetime.datetime instance that is naive unless a tzinfo
    object has been passed.
    """
    return datetime(year=year, month=month, day=day,
        hour=hour, minute=minute, second=second,
        microsecond=microsecond, tzinfo=tzinfo)
1609
1610
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks are converted with localtime().
    """
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
1614
1615
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    The ticks are converted with localtime().
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
1619
1620
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    The ticks are converted with localtime().
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
1624
1625
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is a plain subclass of bytes that adds no behavior of its own.
    """
1628
1629
1630# Additional type helpers for PyGreSQL:
1631
def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
    """Construct an object holding a time interval value.

    Returns a datetime.timedelta instance.
    """
    return timedelta(days=days, hours=hours, minutes=minutes,
        seconds=seconds, microseconds=microseconds)
1636
1637
Uuid = Uuid  # Uuid(...) constructs an object holding a UUID value
1639
1640
class Hstore(dict):
    """Wrapper class for marking hstore values.

    Converting an instance to a string gives a comma separated list
    of key=>value pairs, with keys and values quoted where necessary.
    Keys and values that are not strings are converted with str().
    """

    # a key or value must be quoted if it reads like NULL
    # or if it contains a blank, a comma or one of "=" and ">"
    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
    # double quotes and backslashes must be escaped with a backslash
    _re_escape = regex(r'(["\\])')
    # the text types (str and unicode in Python 2, only str in Python 3)
    _text_types = (str, type(u''))

    @classmethod
    def _quote(cls, s):
        """Return a key or value in quoted and escaped form.

        None is returned as NULL and an empty string as a pair of
        double quotes.
        """
        if s is None:
            return 'NULL'
        if not isinstance(s, cls._text_types):
            # the regular expressions below need text, and falsy
            # values such as 0 must not be taken for an empty string
            s = str(s)
        if not s:
            return '""'
        # check before escaping whether quotes are needed
        quote = cls._re_quote.search(s)
        s = cls._re_escape.sub(r'\\\1', s)
        if quote:
            s = '"%s"' % s
        return s

    def __str__(self):
        """Return all items as comma separated key=>value pairs."""
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
1662
1663
class Json:
    """Construct a wrapper for holding an object serializable to JSON.

    Converting the wrapper to a string returns the wrapped object
    unchanged if it is a string already; any other object is passed
    through the encode function.
    """

    def __init__(self, obj, encode=None):
        """Wrap the object, using jsonencode if no encoder is passed."""
        self.obj = obj
        self.encode = encode if encode else jsonencode

    def __str__(self):
        payload = self.obj
        return (payload if isinstance(payload, basestring)
            else self.encode(payload))
1676
1677
class Literal:
    """Construct a wrapper for holding a literal SQL string.

    The wrapped string is returned unchanged by both str() and the
    __pg_repr__() method.
    """

    def __init__(self, sql):
        """Wrap the given SQL string."""
        self.sql = sql

    def __str__(self):
        """Return the wrapped SQL string."""
        return self.sql

    # the same function also serves as the __pg_repr__() method
    __pg_repr__ = __str__
1688
1689# If run as script, print some information:
1690
if __name__ == '__main__':
    # show the version followed by the usage notes of this module
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.