source: trunk/pgdb.py @ 893

Last change on this file since 893 was 893, checked in by cito, 3 years ago

Add a way to switch off array casting in pgdb

  • Property svn:keywords set to Id
File size: 56.4 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 893 2016-09-21 14:44:59Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence) if
    # they are passed. The binding syntax is the same as the % operator,
    # and the parameter values are quoted and escaped automatically.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function, division
67
68from _pg import *
69
70__version__ = version
71
72from datetime import date, time, datetime, timedelta, tzinfo
73from time import localtime
74from decimal import Decimal
75from uuid import UUID as Uuid
76from math import isnan, isinf
77from collections import namedtuple
78from functools import partial
79from re import compile as regex
80from json import loads as jsondecode, dumps as jsonencode
81
# Compatibility names for code that runs on both Python 2 and Python 3.

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

try:
    # the ABCs live in collections.abc since Python 3.3, and the old
    # aliases in collections were removed in Python 3.10
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
98
99
### Module Constants

# DB-API level supported by this module (compliant with DB API 2.0)
apilevel = '2.0'

# threads may share the module, but not connections
threadsafety = 1

# this module uses extended Python format codes such as %(name)s
paramstyle = 'pyformat'

# shortcut methods have been excluded from DB API 2 and are not
# recommended by the DB SIG, but they can be handy
shortcutmethods = 1
114
115
### Internal Type Handling

try:  # Python >= 3.3
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the argument names of func (via getargspec)."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the parameter names of func (via signature)."""
        params = signature(func).parameters
        return [name for name in params]
129
130try:
131    from datetime import timezone
132except ImportError:  # Python < 3.2
133
134    class timezone(tzinfo):
135        """Simple timezone implementation."""
136
137        def __init__(self, offset, name=None):
138            self.offset = offset
139            if not name:
140                minutes = self.offset.days * 1440 + self.offset.seconds // 60
141                if minutes < 0:
142                    hours, minutes = divmod(-minutes, 60)
143                    hours = -hours
144                else:
145                    hours, minutes = divmod(minutes, 60)
146                name = 'UTC%+03d:%02d' % (hours, minutes)
147            self.name = name
148
149        def utcoffset(self, dt):
150            return self.offset
151
152        def tzname(self, dt):
153            return self.name
154
155        def dst(self, dt):
156            return None
157
158    timezone.utc = timezone(timedelta(0), 'UTC')
159
160    _has_timezone = False
161else:
162    _has_timezone = True
163
164# time zones used in Postgres timestamptz output
165_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
166    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
167    UCT='+0000', UTC='+0000', WET='+0000')
168
169
170def _timezone_as_offset(tz):
171    if tz.startswith(('+', '-')):
172        if len(tz) < 5:
173            return tz + '00'
174        return tz.replace(':', '')
175    return _timezones.get(tz, '+0000')
176
177
178def _get_timezone(tz):
179    tz = _timezone_as_offset(tz)
180    minutes = 60 * int(tz[1:3]) + int(tz[3:5])
181    if tz[0] == '-':
182        minutes = -minutes
183    return timezone(timedelta(minutes=minutes), tz)
184
185
186def decimal_type(decimal_type=None):
187    """Get or set global type to be used for decimal values.
188
189    Note that connections cache cast functions. To be sure a global change
190    is picked up by a running connection, call con.type_cache.reset_typecast().
191    """
192    global Decimal
193    if decimal_type is not None:
194        Decimal = decimal_type
195        set_typecast('numeric', decimal_type)
196    return Decimal
197
198
199def cast_bool(value):
200    """Cast boolean value in database format to bool."""
201    if value:
202        return value[0] in ('t', 'T')
203
204
205def cast_money(value):
206    """Cast money value in database format to Decimal."""
207    if value:
208        value = value.replace('(', '-')
209        return Decimal(''.join(c for c in value if c.isdigit() or c in '.-'))
210
211
212def cast_int2vector(value):
213    """Cast an int2vector value."""
214    return [int(v) for v in value.split()]
215
216
217def cast_date(value, connection):
218    """Cast a date value."""
219    # The output format depends on the server setting DateStyle.  The default
220    # setting ISO and the setting for German are actually unambiguous.  The
221    # order of days and months in the other two settings is however ambiguous,
222    # so at least here we need to consult the setting to properly parse values.
223    if value == '-infinity':
224        return date.min
225    if value == 'infinity':
226        return date.max
227    value = value.split()
228    if value[-1] == 'BC':
229        return date.min
230    value = value[0]
231    if len(value) > 10:
232        return date.max
233    fmt = connection.date_format()
234    return datetime.strptime(value, fmt).date()
235
236
237def cast_time(value):
238    """Cast a time value."""
239    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
240    return datetime.strptime(value, fmt).time()
241
242
243_re_timezone = regex('(.*)([+-].*)')
244
245
246def cast_timetz(value):
247    """Cast a timetz value."""
248    tz = _re_timezone.match(value)
249    if tz:
250        value, tz = tz.groups()
251    else:
252        tz = '+0000'
253    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
254    if _has_timezone:
255        value += _timezone_as_offset(tz)
256        fmt += '%z'
257        return datetime.strptime(value, fmt).timetz()
258    return datetime.strptime(value, fmt).timetz().replace(
259        tzinfo=_get_timezone(tz))
260
261
262def cast_timestamp(value, connection):
263    """Cast a timestamp value."""
264    if value == '-infinity':
265        return datetime.min
266    if value == 'infinity':
267        return datetime.max
268    value = value.split()
269    if value[-1] == 'BC':
270        return datetime.min
271    fmt = connection.date_format()
272    if fmt.endswith('-%Y') and len(value) > 2:
273        value = value[1:5]
274        if len(value[3]) > 4:
275            return datetime.max
276        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
277            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
278    else:
279        if len(value[0]) > 10:
280            return datetime.max
281        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
282    return datetime.strptime(' '.join(value), ' '.join(fmt))
283
284
285def cast_timestamptz(value, connection):
286    """Cast a timestamptz value."""
287    if value == '-infinity':
288        return datetime.min
289    if value == 'infinity':
290        return datetime.max
291    value = value.split()
292    if value[-1] == 'BC':
293        return datetime.min
294    fmt = connection.date_format()
295    if fmt.endswith('-%Y') and len(value) > 2:
296        value = value[1:]
297        if len(value[3]) > 4:
298            return datetime.max
299        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
300            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
301        value, tz = value[:-1], value[-1]
302    else:
303        if fmt.startswith('%Y-'):
304            tz = _re_timezone.match(value[1])
305            if tz:
306                value[1], tz = tz.groups()
307            else:
308                tz = '+0000'
309        else:
310            value, tz = value[:-1], value[-1]
311        if len(value[0]) > 10:
312            return datetime.max
313        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
314    if _has_timezone:
315        value.append(_timezone_as_offset(tz))
316        fmt.append('%z')
317        return datetime.strptime(' '.join(value), ' '.join(fmt))
318    return datetime.strptime(' '.join(value), ' '.join(fmt)).replace(
319        tzinfo=_get_timezone(tz))
320
321
322_re_interval_sql_standard = regex(
323    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
324    '(?:([+-]?[0-9]+)(?!:) ?)?'
325    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
326
327_re_interval_postgres = regex(
328    '(?:([+-]?[0-9]+) ?years? ?)?'
329    '(?:([+-]?[0-9]+) ?mons? ?)?'
330    '(?:([+-]?[0-9]+) ?days? ?)?'
331    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
332
333_re_interval_postgres_verbose = regex(
334    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
335    '(?:([+-]?[0-9]+) ?mons? ?)?'
336    '(?:([+-]?[0-9]+) ?days? ?)?'
337    '(?:([+-]?[0-9]+) ?hours? ?)?'
338    '(?:([+-]?[0-9]+) ?mins? ?)?'
339    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')
340
341_re_interval_iso_8601 = regex(
342    'P(?:([+-]?[0-9]+)Y)?'
343    '(?:([+-]?[0-9]+)M)?'
344    '(?:([+-]?[0-9]+)D)?'
345    '(?:T(?:([+-]?[0-9]+)H)?'
346    '(?:([+-]?[0-9]+)M)?'
347    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')
348
349
350def cast_interval(value):
351    """Cast an interval value."""
352    # The output format depends on the server setting IntervalStyle, but it's
353    # not necessary to consult this setting to parse it.  It's faster to just
354    # check all possible formats, and there is no ambiguity here.
355    m = _re_interval_iso_8601.match(value)
356    if m:
357        m = [d or '0' for d in m.groups()]
358        secs_ago = m.pop(5) == '-'
359        m = [int(d) for d in m]
360        years, mons, days, hours, mins, secs, usecs = m
361        if secs_ago:
362            secs = -secs
363            usecs = -usecs
364    else:
365        m = _re_interval_postgres_verbose.match(value)
366        if m:
367            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
368            secs_ago = m.pop(5) == '-'
369            m = [-int(d) for d in m] if ago else [int(d) for d in m]
370            years, mons, days, hours, mins, secs, usecs = m
371            if secs_ago:
372                secs = - secs
373                usecs = -usecs
374        else:
375            m = _re_interval_postgres.match(value)
376            if m and any(m.groups()):
377                m = [d or '0' for d in m.groups()]
378                hours_ago = m.pop(3) == '-'
379                m = [int(d) for d in m]
380                years, mons, days, hours, mins, secs, usecs = m
381                if hours_ago:
382                    hours = -hours
383                    mins = -mins
384                    secs = -secs
385                    usecs = -usecs
386            else:
387                m = _re_interval_sql_standard.match(value)
388                if m and any(m.groups()):
389                    m = [d or '0' for d in m.groups()]
390                    years_ago = m.pop(0) == '-'
391                    hours_ago = m.pop(3) == '-'
392                    m = [int(d) for d in m]
393                    years, mons, days, hours, mins, secs, usecs = m
394                    if years_ago:
395                        years = -years
396                        mons = -mons
397                    if hours_ago:
398                        hours = -hours
399                        mins = -mins
400                        secs = -secs
401                        usecs = -usecs
402                else:
403                    raise ValueError('Cannot parse interval: %s' % value)
404    days += 365 * years + 30 * mons
405    return timedelta(days=days, hours=hours, minutes=mins,
406        seconds=secs, microseconds=usecs)
407
408
409class Typecasts(dict):
410    """Dictionary mapping database types to typecast functions.
411
412    The cast functions get passed the string representation of a value in
413    the database which they need to convert to a Python object.  The
414    passed string will never be None since NULL values are already be
415    handled before the cast function is called.
416    """
417
418    # the default cast functions
419    # (str functions are ignored but have been added for faster access)
420    defaults = {'char': str, 'bpchar': str, 'name': str,
421        'text': str, 'varchar': str,
422        'bool': cast_bool, 'bytea': unescape_bytea,
423        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
424        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
425        'float4': float, 'float8': float,
426        'numeric': Decimal, 'money': cast_money,
427        'date': cast_date, 'interval': cast_interval,
428        'time': cast_time, 'timetz': cast_timetz,
429        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
430        'int2vector': cast_int2vector, 'uuid': Uuid,
431        'anyarray': cast_array, 'record': cast_record}
432
433    connection = None  # will be set in local connection specific instances
434
435    def __missing__(self, typ):
436        """Create a cast function if it is not cached.
437
438        Note that this class never raises a KeyError,
439        but returns None when no special cast function exists.
440        """
441        if not isinstance(typ, str):
442            raise TypeError('Invalid type: %s' % typ)
443        cast = self.defaults.get(typ)
444        if cast:
445            # store default for faster access
446            cast = self._add_connection(cast)
447            self[typ] = cast
448        elif typ.startswith('_'):
449            # create array cast
450            base_cast = self[typ[1:]]
451            cast = self.create_array_cast(base_cast)
452            if base_cast:
453                # store only if base type exists
454                self[typ] = cast
455        return cast
456
457    @staticmethod
458    def _needs_connection(func):
459        """Check if a typecast function needs a connection argument."""
460        try:
461            args = get_args(func)
462        except (TypeError, ValueError):
463            return False
464        else:
465            return 'connection' in args[1:]
466
467    def _add_connection(self, cast):
468        """Add a connection argument to the typecast function if necessary."""
469        if not self.connection or not self._needs_connection(cast):
470            return cast
471        return partial(cast, connection=self.connection)
472
473    def get(self, typ, default=None):
474        """Get the typecast function for the given database type."""
475        return self[typ] or default
476
477    def set(self, typ, cast):
478        """Set a typecast function for the specified database type(s)."""
479        if isinstance(typ, basestring):
480            typ = [typ]
481        if cast is None:
482            for t in typ:
483                self.pop(t, None)
484                self.pop('_%s' % t, None)
485        else:
486            if not callable(cast):
487                raise TypeError("Cast parameter must be callable")
488            for t in typ:
489                self[t] = self._add_connection(cast)
490                self.pop('_%s' % t, None)
491
492    def reset(self, typ=None):
493        """Reset the typecasts for the specified type(s) to their defaults.
494
495        When no type is specified, all typecasts will be reset.
496        """
497        defaults = self.defaults
498        if typ is None:
499            self.clear()
500            self.update(defaults)
501        else:
502            if isinstance(typ, basestring):
503                typ = [typ]
504            for t in typ:
505                cast = defaults.get(t)
506                if cast:
507                    self[t] = self._add_connection(cast)
508                    t = '_%s' % t
509                    cast = defaults.get(t)
510                    if cast:
511                        self[t] = self._add_connection(cast)
512                    else:
513                        self.pop(t, None)
514                else:
515                    self.pop(t, None)
516                    self.pop('_%s' % t, None)
517
518    def create_array_cast(self, basecast):
519        """Create an array typecast for the given base cast."""
520        cast_array = self['anyarray']
521        def cast(v):
522            return cast_array(v, basecast)
523        return cast
524
525    def create_record_cast(self, name, fields, casts):
526        """Create a named record typecast for the given fields and casts."""
527        cast_record = self['record']
528        record = namedtuple(name, fields)
529        def cast(v):
530            return record(*cast_record(v, casts))
531        return cast
532
533
534_typecasts = Typecasts()  # this is the global typecast dictionary
535
536
537def get_typecast(typ):
538    """Get the global typecast function for the given database type(s)."""
539    return _typecasts.get(typ)
540
541
542def set_typecast(typ, cast):
543    """Set a global typecast function for the given database type(s).
544
545    Note that connections cache cast functions. To be sure a global change
546    is picked up by a running connection, call con.type_cache.reset_typecast().
547    """
548    _typecasts.set(typ, cast)
549
550
551def reset_typecast(typ=None):
552    """Reset the global typecasts for the given type(s) to their default.
553
554    When no type is specified, all typecasts will be reset.
555
556    Note that connections cache cast functions. To be sure a global change
557    is picked up by a running connection, call con.type_cache.reset_typecast().
558    """
559    _typecasts.reset(typ)
560
561
562class LocalTypecasts(Typecasts):
563    """Map typecasts, including local composite types, to cast functions."""
564
565    defaults = _typecasts
566
567    connection = None  # will be set in a connection specific instance
568
569    def __missing__(self, typ):
570        """Create a cast function if it is not cached."""
571        if typ.startswith('_'):
572            base_cast = self[typ[1:]]
573            cast = self.create_array_cast(base_cast)
574            if base_cast:
575                self[typ] = cast
576        else:
577            cast = self.defaults.get(typ)
578            if cast:
579                cast = self._add_connection(cast)
580                self[typ] = cast
581            else:
582                fields = self.get_fields(typ)
583                if fields:
584                    casts = [self[field.type] for field in fields]
585                    fields = [field.name for field in fields]
586                    cast = self.create_record_cast(typ, fields, casts)
587                    self[typ] = cast
588        return cast
589
590    def get_fields(self, typ):
591        """Return the fields for the given record type.
592
593        This method will be replaced with a method that looks up the fields
594        using the type cache of the connection.
595        """
596        return []
597
598
599class TypeCode(str):
600    """Class representing the type_code used by the DB-API 2.0.
601
602    TypeCode objects are strings equal to the PostgreSQL type name,
603    but carry some additional information.
604    """
605
606    @classmethod
607    def create(cls, oid, name, len, type, category, delim, relid):
608        """Create a type code for a PostgreSQL data type."""
609        self = cls(name)
610        self.oid = oid
611        self.len = len
612        self.type = type
613        self.category = category
614        self.delim = delim
615        self.relid = relid
616        return self
617
618FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
619
620
621class TypeCache(dict):
622    """Cache for database types.
623
624    This cache maps type OIDs and names to TypeCode strings containing
625    important information on the associated database type.
626    """
627
628    def __init__(self, cnx):
629        """Initialize type cache for connection."""
630        super(TypeCache, self).__init__()
631        self._escape_string = cnx.escape_string
632        self._src = cnx.source()
633        self._typecasts = LocalTypecasts()
634        self._typecasts.get_fields = self.get_fields
635        self._typecasts.connection = cnx
636        if cnx.server_version < 80400:
637            # older remote databases (not officially supported)
638            self._query_pg_type = ("SELECT oid, typname,"
639                " typlen, typtype, null as typcategory, typdelim, typrelid"
640                " FROM pg_type WHERE oid=%s")
641        else:
642            self._query_pg_type = ("SELECT oid, typname,"
643                " typlen, typtype, typcategory, typdelim, typrelid"
644                " FROM pg_type WHERE oid=%s")
645
646    def __missing__(self, key):
647        """Get the type info from the database if it is not cached."""
648        if isinstance(key, int):
649            oid = key
650        else:
651            if '.' not in key and '"' not in key:
652                key = '"%s"' % (key,)
653            oid = "'%s'::regtype" % (self._escape_string(key),)
654        try:
655            self._src.execute(self._query_pg_type % (oid,))
656        except ProgrammingError:
657            res = None
658        else:
659            res = self._src.fetch(1)
660        if not res:
661            raise KeyError('Type %s could not be found' % (key,))
662        res = res[0]
663        type_code = TypeCode.create(int(res[0]), res[1],
664            int(res[2]), res[3], res[4], res[5], int(res[6]))
665        self[type_code.oid] = self[str(type_code)] = type_code
666        return type_code
667
668    def get(self, key, default=None):
669        """Get the type even if it is not cached."""
670        try:
671            return self[key]
672        except KeyError:
673            return default
674
675    def get_fields(self, typ):
676        """Get the names and types of the fields of composite types."""
677        if not isinstance(typ, TypeCode):
678            typ = self.get(typ)
679            if not typ:
680                return None
681        if not typ.relid:
682            return None  # this type is not composite
683        self._src.execute("SELECT attname, atttypid"
684            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
685            " AND NOT attisdropped ORDER BY attnum" % (typ.relid,))
686        return [FieldInfo(name, self.get(int(oid)))
687            for name, oid in self._src.fetch(-1)]
688
689    def get_typecast(self, typ):
690        """Get the typecast function for the given database type."""
691        return self._typecasts.get(typ)
692
693    def set_typecast(self, typ, cast):
694        """Set a typecast function for the specified database type(s)."""
695        self._typecasts.set(typ, cast)
696
697    def reset_typecast(self, typ=None):
698        """Reset the typecast function for the specified database type(s)."""
699        self._typecasts.reset(typ)
700
701    def typecast(self, value, typ):
702        """Cast the given value according to the given database type."""
703        if value is None:
704            # for NULL values, no typecast is necessary
705            return None
706        cast = self.get_typecast(typ)
707        if not cast or cast is str:
708            # no typecast is necessary
709            return value
710        return cast(value)
711
712
713class _quotedict(dict):
714    """Dictionary with auto quoting of its items.
715
716    The quote attribute must be set to the desired quote function.
717    """
718
719    def __getitem__(self, key):
720        return self.quote(super(_quotedict, self).__getitem__(key))
721
722
723### Error messages
724
725def _db_error(msg, cls=DatabaseError):
726    """Return DatabaseError with empty sqlstate attribute."""
727    error = cls(msg)
728    error.sqlstate = None
729    return error
730
731
732def _op_error(msg):
733    """Return OperationalError."""
734    return _db_error(msg, OperationalError)
735
736
737### Cursor Object
738
739class Cursor(object):
740    """Cursor object."""
741
742    def __init__(self, dbcnx):
743        """Create a cursor object for the database connection."""
744        self.connection = self._dbcnx = dbcnx
745        self._cnx = dbcnx._cnx
746        self.type_cache = dbcnx.type_cache
747        self._src = self._cnx.source()
748        # the official attribute for describing the result columns
749        self._description = None
750        if self.row_factory is Cursor.row_factory:
751            # the row factory needs to be determined dynamically
752            self.row_factory = None
753        else:
754            self.build_row_factory = None
755        self.rowcount = -1
756        self.arraysize = 1
757        self.lastrowid = None
758
759    def __iter__(self):
760        """Make cursor compatible to the iteration protocol."""
761        return self
762
763    def __enter__(self):
764        """Enter the runtime context for the cursor object."""
765        return self
766
767    def __exit__(self, et, ev, tb):
768        """Exit the runtime context for the cursor object."""
769        self.close()
770
771    def _quote(self, value):
772        """Quote value depending on its type."""
773        if value is None:
774            return 'NULL'
775        if isinstance(value, (Hstore, Json)):
776            value = str(value)
777        if isinstance(value, basestring):
778            if isinstance(value, Binary):
779                value = self._cnx.escape_bytea(value)
780                if bytes is not str:  # Python >= 3.0
781                    value = value.decode('ascii')
782            else:
783                value = self._cnx.escape_string(value)
784            return "'%s'" % (value,)
785        if isinstance(value, float):
786            if isinf(value):
787                return "'-Infinity'" if value < 0 else "'Infinity'"
788            if isnan(value):
789                return "'NaN'"
790            return value
791        if isinstance(value, (int, long, Decimal, Literal)):
792            return value
793        if isinstance(value, datetime):
794            if value.tzinfo:
795                return "'%s'::timestamptz" % (value,)
796            return "'%s'::timestamp" % (value,)
797        if isinstance(value, date):
798            return "'%s'::date" % (value,)
799        if isinstance(value, time):
800            if value.tzinfo:
801                return "'%s'::timetz" % (value,)
802            return "'%s'::time" % value
803        if isinstance(value, timedelta):
804            return "'%s'::interval" % (value,)
805        if isinstance(value, Uuid):
806            return "'%s'::uuid" % (value,)
807        if isinstance(value, list):
808            # Quote value as an ARRAY constructor. This is better than using
809            # an array literal because it carries the information that this is
810            # an array and not a string.  One issue with this syntax is that
811            # you need to add an explicit typecast when passing empty arrays.
812            # The ARRAY keyword is actually only necessary at the top level.
813            if not value:  # exception for empty array
814                return "'{}'"
815            q = self._quote
816            try:
817                return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),)
818            except UnicodeEncodeError:  # Python 2 with non-ascii values
819                return u'ARRAY[%s]' % (','.join(unicode(q(v)) for v in value),)
820        if isinstance(value, tuple):
821            # Quote as a ROW constructor.  This is better than using a record
822            # literal because it carries the information that this is a record
823            # and not a string.  We don't use the keyword ROW in order to make
824            # this usable with the IN syntax as well.  It is only necessary
825            # when the records has a single column which is not really useful.
826            q = self._quote
827            try:
828                return '(%s)' % (','.join(str(q(v)) for v in value),)
829            except UnicodeEncodeError:  # Python 2 with non-ascii values
830                return u'(%s)' % (','.join(unicode(q(v)) for v in value),)
831        try:
832            value = value.__pg_repr__()
833        except AttributeError:
834            raise InterfaceError(
835                'Do not know how to adapt type %s' % (type(value),))
836        if isinstance(value, (tuple, list)):
837            value = self._quote(value)
838        return value
839
840    def _quoteparams(self, string, parameters):
841        """Quote parameters.
842
843        This function works for both mappings and sequences.
844
845        The function should be used even when there are no parameters,
846        so that we have a consistent behavior regarding percent signs.
847        """
848        if not parameters:
849            try:
850                return string % ()  # unescape literal quotes if possible
851            except (TypeError, ValueError):
852                return string  # silently accept unescaped quotes
853        if isinstance(parameters, dict):
854            parameters = _quotedict(parameters)
855            parameters.quote = self._quote
856        else:
857            parameters = tuple(map(self._quote, parameters))
858        return string % parameters
859
860    def _make_description(self, info):
861        """Make the description tuple for the given field info."""
862        name, typ, size, mod = info[1:]
863        type_code = self.type_cache[typ]
864        if mod > 0:
865            mod -= 4
866        if type_code == 'numeric':
867            precision, scale = mod >> 16, mod & 0xffff
868            size = precision
869        else:
870            if not size:
871                size = type_code.size
872            if size == -1:
873                size = mod
874            precision = scale = None
875        return CursorDescription(name, type_code,
876            None, size, precision, scale, None)
877
878    @property
879    def description(self):
880        """Read-only attribute describing the result columns."""
881        descr = self._description
882        if self._description is True:
883            make = self._make_description
884            descr = [make(info) for info in self._src.listinfo()]
885            self._description = descr
886        return descr
887
888    @property
889    def colnames(self):
890        """Unofficial convenience method for getting the column names."""
891        return [d[0] for d in self.description]
892
893    @property
894    def coltypes(self):
895        """Unofficial convenience method for getting the column types."""
896        return [d[1] for d in self.description]
897
898    def close(self):
899        """Close the cursor object."""
900        self._src.close()
901
902    def execute(self, operation, parameters=None):
903        """Prepare and execute a database operation (query or command)."""
904        # The parameters may also be specified as list of tuples to e.g.
905        # insert multiple rows in a single operation, but this kind of
906        # usage is deprecated.  We make several plausibility checks because
907        # tuples can also be passed with the meaning of ROW constructors.
908        if (parameters and isinstance(parameters, list)
909                and len(parameters) > 1
910                and all(isinstance(p, tuple) for p in parameters)
911                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
912            return self.executemany(operation, parameters)
913        else:
914            # not a list of tuples
915            return self.executemany(operation, [parameters])
916
    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        If the connection has no transaction in progress yet, one is started
        with BEGIN first.  Then each item of seq_of_parameters is quoted into
        the operation and the resulting statement is executed in turn.

        Afterwards, rowcount, lastrowid and the (lazily built) description
        are set from the last result, and for queries returning rows a new
        row factory is built if build_row_factory is set.

        Returns the cursor itself so that calls can be chained, or None
        without doing anything if seq_of_parameters is empty.

        DatabaseError is re-raised unchanged; other pgdb Error instances
        are converted with _db_error (passing InterfaceError), and any other
        exception is converted with _op_error.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self._description = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        # sql always holds the statement being run, for the error messages
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                # start a transaction unless one is already in progress
                try:
                    self._src.execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("Can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                # NOTE(review): truthy results are summed into rowcount, so
                # rows is presumably the number of rows affected by a
                # command (falsy otherwise) - confirm against the C module
                if rows:
                    rowcount += rows
                else:
                    # NOTE(review): rowcount is overwritten below in both
                    # branches, so this assignment has no lasting effect
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "Error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("Internal error in '%s': %s" % (sql, err))
        # then initialize result row count and description
        if self._src.resulttype == RESULT_DQL:
            self._description = True  # fetch on demand
            self.rowcount = self._src.ntuples
            self.lastrowid = None
            # a falsy build_row_factory keeps the current row_factory
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self
964
965    def fetchone(self):
966        """Fetch the next row of a query result set."""
967        res = self.fetchmany(1, False)
968        try:
969            return res[0]
970        except IndexError:
971            return None
972
973    def fetchall(self):
974        """Fetch all (remaining) rows of a query result."""
975        return self.fetchmany(-1, False)
976
977    def fetchmany(self, size=None, keep=False):
978        """Fetch the next set of rows of a query result.
979
980        The number of rows to fetch per call is specified by the
981        size parameter. If it is not given, the cursor's arraysize
982        determines the number of rows to be fetched. If you set
983        the keep parameter to true, this is kept as new arraysize.
984        """
985        if size is None:
986            size = self.arraysize
987        if keep:
988            self.arraysize = size
989        try:
990            result = self._src.fetch(size)
991        except DatabaseError:
992            raise
993        except Error as err:
994            raise _db_error(str(err))
995        typecast = self.type_cache.typecast
996        return [self.row_factory([typecast(value, typ)
997            for typ, value in zip(self.coltypes, row)]) for row in result]
998
999    def callproc(self, procname, parameters=None):
1000        """Call a stored database procedure with the given name.
1001
1002        The sequence of parameters must contain one entry for each input
1003        argument that the procedure expects. The result of the call is the
1004        same as this input sequence; replacement of output and input/output
1005        parameters in the return value is currently not supported.
1006
1007        The procedure may also provide a result set as output. These can be
1008        requested through the standard fetch methods of the cursor.
1009        """
1010        n = parameters and len(parameters) or 0
1011        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
1012        self.execute(query, parameters)
1013        return parameters
1014
1015    def copy_from(self, stream, table,
1016            format=None, sep=None, null=None, size=None, columns=None):
1017        """Copy data from an input stream to the specified table.
1018
1019        The input stream can be a file-like object with a read() method or
1020        it can also be an iterable returning a row or multiple rows of input
1021        on each iteration.
1022
1023        The format must be text, csv or binary. The sep option sets the
1024        column separator (delimiter) used in the non binary formats.
1025        The null option sets the textual representation of NULL in the input.
1026
1027        The size option sets the size of the buffer used when reading data
1028        from file-like objects.
1029
1030        The copy operation can be restricted to a subset of columns. If no
1031        columns are specified, all of them will be copied.
1032        """
1033        binary_format = format == 'binary'
1034        try:
1035            read = stream.read
1036        except AttributeError:
1037            if size:
1038                raise ValueError("Size must only be set for file-like objects")
1039            if binary_format:
1040                input_type = bytes
1041                type_name = 'byte strings'
1042            else:
1043                input_type = basestring
1044                type_name = 'strings'
1045
1046            if isinstance(stream, basestring):
1047                if not isinstance(stream, input_type):
1048                    raise ValueError("The input must be %s" % (type_name,))
1049                if not binary_format:
1050                    if isinstance(stream, str):
1051                        if not stream.endswith('\n'):
1052                            stream += '\n'
1053                    else:
1054                        if not stream.endswith(b'\n'):
1055                            stream += b'\n'
1056
1057                def chunks():
1058                    yield stream
1059
1060            elif isinstance(stream, Iterable):
1061
1062                def chunks():
1063                    for chunk in stream:
1064                        if not isinstance(chunk, input_type):
1065                            raise ValueError(
1066                                "Input stream must consist of %s"
1067                                % (type_name,))
1068                        if isinstance(chunk, str):
1069                            if not chunk.endswith('\n'):
1070                                chunk += '\n'
1071                        else:
1072                            if not chunk.endswith(b'\n'):
1073                                chunk += b'\n'
1074                        yield chunk
1075
1076            else:
1077                raise TypeError("Need an input stream to copy from")
1078        else:
1079            if size is None:
1080                size = 8192
1081            elif not isinstance(size, int):
1082                raise TypeError("The size option must be an integer")
1083            if size > 0:
1084
1085                def chunks():
1086                    while True:
1087                        buffer = read(size)
1088                        yield buffer
1089                        if not buffer or len(buffer) < size:
1090                            break
1091
1092            else:
1093
1094                def chunks():
1095                    yield read()
1096
1097        if not table or not isinstance(table, basestring):
1098            raise TypeError("Need a table to copy to")
1099        if table.lower().startswith('select'):
1100                raise ValueError("Must specify a table, not a query")
1101        else:
1102            table = '"%s"' % (table,)
1103        operation = ['copy %s' % (table,)]
1104        options = []
1105        params = []
1106        if format is not None:
1107            if not isinstance(format, basestring):
1108                raise TypeError("The frmat option must be be a string")
1109            if format not in ('text', 'csv', 'binary'):
1110                raise ValueError("Invalid format")
1111            options.append('format %s' % (format,))
1112        if sep is not None:
1113            if not isinstance(sep, basestring):
1114                raise TypeError("The sep option must be a string")
1115            if format == 'binary':
1116                raise ValueError(
1117                    "The sep option is not allowed with binary format")
1118            if len(sep) != 1:
1119                raise ValueError(
1120                    "The sep option must be a single one-byte character")
1121            options.append('delimiter %s')
1122            params.append(sep)
1123        if null is not None:
1124            if not isinstance(null, basestring):
1125                raise TypeError("The null option must be a string")
1126            options.append('null %s')
1127            params.append(null)
1128        if columns:
1129            if not isinstance(columns, basestring):
1130                columns = ','.join('"%s"' % (col,) for col in columns)
1131            operation.append('(%s)' % (columns,))
1132        operation.append("from stdin")
1133        if options:
1134            operation.append('(%s)' % (','.join(options),))
1135        operation = ' '.join(operation)
1136
1137        putdata = self._src.putdata
1138        self.execute(operation, params)
1139
1140        try:
1141            for chunk in chunks():
1142                putdata(chunk)
1143        except BaseException as error:
1144            self.rowcount = -1
1145            # the following call will re-raise the error
1146            putdata(error)
1147        else:
1148            self.rowcount = putdata(None)
1149
1150        # return the cursor object, so you can chain operations
1151        return self
1152
1153    def copy_to(self, stream, table,
1154            format=None, sep=None, null=None, decode=None, columns=None):
1155        """Copy data from the specified table to an output stream.
1156
1157        The output stream can be a file-like object with a write() method or
1158        it can also be None, in which case the method will return a generator
1159        yielding a row on each iteration.
1160
1161        Output will be returned as byte strings unless you set decode to true.
1162
1163        Note that you can also use a select query instead of the table name.
1164
1165        The format must be text, csv or binary. The sep option sets the
1166        column separator (delimiter) used in the non binary formats.
1167        The null option sets the textual representation of NULL in the output.
1168
1169        The copy operation can be restricted to a subset of columns. If no
1170        columns are specified, all of them will be copied.
1171        """
1172        binary_format = format == 'binary'
1173        if stream is not None:
1174            try:
1175                write = stream.write
1176            except AttributeError:
1177                raise TypeError("Need an output stream to copy to")
1178        if not table or not isinstance(table, basestring):
1179            raise TypeError("Need a table to copy to")
1180        if table.lower().startswith('select'):
1181            if columns:
1182                raise ValueError("Columns must be specified in the query")
1183            table = '(%s)' % (table,)
1184        else:
1185            table = '"%s"' % (table,)
1186        operation = ['copy %s' % (table,)]
1187        options = []
1188        params = []
1189        if format is not None:
1190            if not isinstance(format, basestring):
1191                raise TypeError("The format option must be a string")
1192            if format not in ('text', 'csv', 'binary'):
1193                raise ValueError("Invalid format")
1194            options.append('format %s' % (format,))
1195        if sep is not None:
1196            if not isinstance(sep, basestring):
1197                raise TypeError("The sep option must be a string")
1198            if binary_format:
1199                raise ValueError(
1200                    "The sep option is not allowed with binary format")
1201            if len(sep) != 1:
1202                raise ValueError(
1203                    "The sep option must be a single one-byte character")
1204            options.append('delimiter %s')
1205            params.append(sep)
1206        if null is not None:
1207            if not isinstance(null, basestring):
1208                raise TypeError("The null option must be a string")
1209            options.append('null %s')
1210            params.append(null)
1211        if decode is None:
1212            if format == 'binary':
1213                decode = False
1214            else:
1215                decode = str is unicode
1216        else:
1217            if not isinstance(decode, (int, bool)):
1218                raise TypeError("The decode option must be a boolean")
1219            if decode and binary_format:
1220                raise ValueError(
1221                    "The decode option is not allowed with binary format")
1222        if columns:
1223            if not isinstance(columns, basestring):
1224                columns = ','.join('"%s"' % (col,) for col in columns)
1225            operation.append('(%s)' % (columns,))
1226
1227        operation.append("to stdout")
1228        if options:
1229            operation.append('(%s)' % (','.join(options),))
1230        operation = ' '.join(operation)
1231
1232        getdata = self._src.getdata
1233        self.execute(operation, params)
1234
1235        def copy():
1236            self.rowcount = 0
1237            while True:
1238                row = getdata(decode)
1239                if isinstance(row, int):
1240                    if self.rowcount != row:
1241                        self.rowcount = row
1242                    break
1243                self.rowcount += 1
1244                yield row
1245
1246        if stream is None:
1247            # no input stream, return the generator
1248            return copy()
1249
1250        # write the rows to the file-like input stream
1251        for row in copy():
1252            write(row)
1253
1254        # return the cursor object, so you can chain operations
1255        return self
1256
1257    def __next__(self):
1258        """Return the next row (support for the iteration protocol)."""
1259        res = self.fetchone()
1260        if res is None:
1261            raise StopIteration
1262        return res
1263
1264    # Note that since Python 2.6 the iterator protocol uses __next()__
1265    # instead of next(), we keep it only for backward compatibility of pgdb.
1266    next = __next__
1267
1268    @staticmethod
1269    def nextset():
1270        """Not supported."""
1271        raise NotSupportedError("The nextset() method is not supported")
1272
1273    @staticmethod
1274    def setinputsizes(sizes):
1275        """Not supported."""
1276        pass  # unsupported, but silently passed
1277
1278    @staticmethod
1279    def setoutputsize(size, column=0):
1280        """Not supported."""
1281        pass  # unsupported, but silently passed
1282
1283    @staticmethod
1284    def row_factory(row):
1285        """Process rows before they are returned.
1286
1287        You can overwrite this statically with a custom row factory, or
1288        you can build a row factory dynamically with build_row_factory().
1289
1290        For example, you can create a Cursor class that returns rows as
1291        Python dictionaries like this:
1292
1293            class DictCursor(pgdb.Cursor):
1294
1295                def row_factory(self, row):
1296                    return {desc[0]: value
1297                        for desc, value in zip(self.description, row)}
1298
1299            cur = DictCursor(con)  # get one DictCursor instance or
1300            con.cursor_type = DictCursor  # always use DictCursor instances
1301        """
1302        raise NotImplementedError
1303
1304    def build_row_factory(self):
1305        """Build a row factory based on the current description.
1306
1307        This implementation builds a row factory for creating named tuples.
1308        You can overwrite this method if you want to dynamically create
1309        different row factories whenever the column description changes.
1310        """
1311        colnames = self.colnames
1312        if colnames:
1313            try:
1314                try:
1315                    return namedtuple('Row', colnames, rename=True)._make
1316                except TypeError:  # Python 2.6 and 3.0 do not support rename
1317                    colnames = [v if v.isalnum() else 'column_%d' % (n,)
1318                             for n, v in enumerate(colnames)]
1319                    return namedtuple('Row', colnames)._make
1320            except ValueError:  # there is still a problem with the field names
1321                colnames = ['column_%d' % (n,) for n in range(len(colnames))]
1322                return namedtuple('Row', colnames)._make
1323
1324
# Entry type of Cursor.description: one named tuple with the seven
# DB-API 2 items for each result column.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
1328
1329
1330### Connection Objects
1331
class Connection(object):
    """Connection object.

    Wraps a low-level connection, keeps track of whether a transaction
    has been started, and creates cursors of the class set as cursor_type.
    """

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object."""
        self._cnx = cnx  # low-level connection
        self._tnx = False  # whether a transaction is in progress
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor
        try:
            # probe the connection; failure means it cannot be used
            self._cnx.source()
        except Exception:
            raise _op_error("Invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is committed on a normal exit and rolled back when the context
        was left through an exception.
        """
        if et is not None or ev is not None or tb is not None:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back first; a DatabaseError raised
        by that rollback is ignored so that the connection still gets
        closed.  Closing an already closed connection raises an error.
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass
        self._cnx.close()
        self._cnx = None

    @property
    def closed(self):
        """Check whether the connection has been closed or is broken."""
        cnx = self._cnx
        try:
            # NOTE(review): a status of 1 presumably means "connection OK"
            return not cnx or cnx.status != 1
        except TypeError:
            return True

    def _end_transaction(self, command, failure):
        """Send the given command (COMMIT or ROLLBACK) if in a transaction.

        The transaction flag is cleared before sending the command, so it
        is reset even when the command fails.
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return
        self._tnx = False
        try:
            self._cnx.source().execute(command)
        except DatabaseError:
            raise
        except Exception:
            raise _op_error(failure)

    def commit(self):
        """Commit any pending transaction to the database."""
        self._end_transaction("COMMIT", "Can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        self._end_transaction("ROLLBACK", "Can't rollback")

    def cursor(self):
        """Return a new cursor object using the connection."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("Invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            new_cursor = self.cursor()
            new_cursor.execute(operation, params)
            return new_cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            new_cursor = self.cursor()
            new_cursor.executemany(operation, param_seq)
            return new_cursor
1447
1448
1449### Module Interface
1450
# Keep a reference to the connect function in scope at this point
# (presumably the low-level one from the C extension - confirm), because
# the name is rebound by the DB-API connect() function defined below,
# which calls it under this private name.
_connect = connect
1452
def connect(dsn=None,
        user=None, password=None,
        host=None, database=None, **kwargs):
    """Connect to a database.

    The dsn may be a string of the form 'host:database:user:password:opt'
    where all parts are optional.  The user, password and database keyword
    arguments override the corresponding dsn parts.  The host keyword
    argument also overrides the dsn part and may include a port in the
    form 'host:port'.

    Any additional keyword arguments are passed on as parameters in a
    libpq connection info string (which then takes the place of dbname).

    Returns a Connection object.
    """
    # first get params from DSN
    dbport = -1
    dbhost = ""
    dbname = ""
    dbuser = ""
    dbpasswd = ""
    dbopt = ""
    try:
        params = dsn.split(":")
        dbhost = params[0]
        dbname = params[1]
        dbuser = params[2]
        dbpasswd = params[3]
        dbopt = params[4]
    except (AttributeError, IndexError, TypeError):
        pass  # missing (trailing) parts keep their defaults

    # override if necessary
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbname = database
    if host is not None:
        try:
            params = host.split(":")
            dbhost = params[0]
            dbport = int(params[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # no (valid) port given

    # empty host is localhost
    if dbhost == "":
        dbhost = None
    if dbuser == "":
        dbuser = None

    # pass keyword arguments as connection info string
    if kwargs:
        kwargs = list(kwargs.items())
        if '=' in dbname:
            dbname = [dbname]
        else:
            kwargs.insert(0, ('dbname', dbname))
            dbname = []
        for kw, value in kwargs:
            value = str(value)
            if (not value or "'" in value or '\\' in value
                    or any(c.isspace() for c in value)):
                # Values that are empty or contain whitespace, quotes or
                # backslashes must be single-quoted, with quotes and
                # backslashes escaped.  Backslashes must be escaped first,
                # otherwise the backslash put in front of a quote would be
                # doubled as well and terminate the value prematurely.
                value = "'%s'" % (value.replace(
                    '\\', '\\\\').replace("'", "\\'"),)
            dbname.append('%s=%s' % (kw, value))
        dbname = ' '.join(dbname)

    # open the connection
    cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
1514
1515
1516### Types Handling
1517
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is created from a whitespace separated string of type names or
    from any iterable of names.  It compares equal to each name it contains,
    also when that name is prefixed with an underscore (as for array types).
    """

    def __new__(cls, values):
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Overriding __eq__ makes Python 3 set __hash__ to None, which would make
    # type objects unusable in sets and as dictionary keys.  Restore the
    # frozenset hash, which agrees with the frozenset comparison used above
    # for all operands that are not strings.
    __hash__ = frozenset.__hash__
1545
1546
class ArrayType:
    """Type class for PostgreSQL array types.

    Compares equal to other ArrayType instances and to every type name
    starting with an underscore (PostgreSQL names array types by prefixing
    the element type name with an underscore).
    """

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return isinstance(other, ArrayType)
        return other.startswith('_')

    def __ne__(self, other):
        if not isinstance(other, basestring):
            return not isinstance(other, ArrayType)
        return not other.startswith('_')
1561
1562
class RecordType:
    """Type class for PostgreSQL record types.

    Compares equal to composite type codes (TypeCode instances whose type
    attribute is 'c'), to the type name 'record' and to other RecordType
    instances.
    """

    def __eq__(self, other):
        if isinstance(other, TypeCode):
            return other.type == 'c'
        if isinstance(other, basestring):
            return other == 'record'
        return isinstance(other, RecordType)

    def __ne__(self, other):
        if isinstance(other, TypeCode):
            return other.type != 'c'
        if isinstance(other, basestring):
            return other != 'record'
        return not isinstance(other, RecordType)
1581
1582
# Mandatory type objects defined by DB-API 2 specs
# (each one compares equal to any of the PostgreSQL type names it lists):

STRING = Type(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = Type(['bytea'])
NUMBER = Type(['int2', 'int4', 'serial', 'int8',
               'float4', 'float8', 'numeric', 'money'])
# abstime and reltime are very old types
DATETIME = Type(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
                 'interval', 'abstime', 'reltime'])
ROWID = Type(['oid'])


# Additional type objects (more specific):

BOOL = Type(['bool'])
SMALLINT = Type(['int2'])
INTEGER = Type(['int2', 'int4', 'int8', 'serial'])
LONG = Type(['int8'])
FLOAT = Type(['float4', 'float8'])
NUMERIC = Type(['numeric'])
MONEY = Type(['money'])
DATE = Type(['date'])
TIME = Type(['time', 'timetz'])
TIMESTAMP = Type(['timestamp', 'timestamptz'])
INTERVAL = Type(['interval'])
UUID = Type(['uuid'])
HSTORE = Type(['hstore'])
JSON = Type(['json', 'jsonb'])

# Type object for arrays (also equate to their base types):

ARRAY = ArrayType()

# Type object for records (encompassing all composite types):

RECORD = RecordType()
1617
1618
1619# Mandatory type helpers defined by DB-API 2 specs:
1620
def Date(year, month, day):
    """Construct an object holding a date value (a datetime.date)."""
    return date(year=year, month=month, day=day)
1624
1625
def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Construct an object holding a time value (a datetime.time)."""
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond, tzinfo=tzinfo)
1629
1630
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Construct an object holding a time stamp value (a datetime)."""
    return datetime(year=year, month=month, day=day,
                    hour=hour, minute=minute, second=second,
                    microsecond=microsecond, tzinfo=tzinfo)
1635
1636
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks value (seconds since the epoch) is interpreted in local time.
    """
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
1640
1641
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    The ticks value is interpreted in local time; fractions of a second
    are dropped.
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
1645
1646
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    The ticks value is interpreted in local time; fractions of a second
    are dropped.
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
1650
1651
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is a plain subclass of bytes without any behavior of its own.
    """

    pass
1654
1655
1656# Additional type helpers for PyGreSQL:
1657
def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
    """Construct an object holding a time interval value (a timedelta)."""
    return timedelta(days=days, hours=hours, minutes=minutes,
                     seconds=seconds, microseconds=microseconds)
1662
1663
# NOTE(review): Uuid is presumably the UUID class imported under this name
# at the top of the module; it is re-bound here to list it among the type
# helpers - confirm.
Uuid = Uuid  # Construct an object holding a UUID value
1665
1666
class Hstore(dict):
    """Wrapper class for marking hstore values.

    Converting an instance to a string gives the hstore text representation
    of its items, e.g. 'key=>value,"a b"=>NULL'.
    """

    # Keys and values must be double-quoted if they are the word NULL (in any
    # case) or contain whitespace, commas, equal or greater-than signs.
    # Whitespace is not limited to blanks: tabs and newlines also end an
    # unquoted hstore string.
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>]')
    # double quotes and backslashes are always escaped with a backslash
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s):
        """Quote a single key or value for hstore (None becomes NULL)."""
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        quote = cls._re_quote.search(s)
        s = cls._re_escape.sub(r'\\\1', s)
        if quote:
            s = '"%s"' % (s,)
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
1688
1689
class Json:
    """Construct a wrapper for holding an object serializable to JSON.

    Converting the wrapper to a string gives the JSON representation of the
    wrapped object, created with the given encode function or, if that is
    not set, with the module's default JSON encoder.  Strings are taken to
    be JSON already and are returned unchanged.
    """

    def __init__(self, obj, encode=None):
        self.encode = encode or jsonencode  # fall back to default encoder
        self.obj = obj

    def __str__(self):
        value = self.obj
        return value if isinstance(value, basestring) else self.encode(value)
1702
1703
class Literal:
    """Construct a wrapper for holding a literal SQL string.

    Both str() and the __pg_repr__() hook (which the cursor's quoting
    code calls for adaptation) return the wrapped SQL unchanged.
    """

    def __init__(self, sql):
        self.sql = sql

    def __str__(self):
        sql = self.sql
        return sql

    __pg_repr__ = __str__
1714
# If run as script, print some information:
# the PyGreSQL version, an empty line and the module docstring.

if __name__ == '__main__':
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.