source: trunk/pgdb.py @ 881

Last change on this file since 881 was 881, checked in by cito, 3 years ago

Add memory leak test for pgdb connections

  • Property svn:keywords set to Id
File size: 56.3 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 881 2016-07-23 19:41:16Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is that of the % operator
    # (pyformat param style), and param values are quoted according
    # to their Python type.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
    # Note that display_size and null_ok are not implemented, and
    # precision and scale are only provided for numeric columns.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function, division
67
68from _pg import *
69
70__version__ = version
71
72from datetime import date, time, datetime, timedelta, tzinfo
73from time import localtime
74from decimal import Decimal
75from uuid import UUID as Uuid
76from math import isnan, isinf
77from collections import namedtuple
78from functools import partial
79from re import compile as regex
80from json import loads as jsondecode, dumps as jsonencode
81
# Compatibility names for running on both Python 2 and Python 3.

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# The ABC aliases in the collections module were deprecated in Python 3.3
# and removed in Python 3.10, so prefer collections.abc when it exists.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
98
99
100### Module Constants
101
# the DB-API level this module complies with
apilevel = '2.0'

# threads may share the module, but must not share connections
threadsafety = 1

# parameters are written with Python's extended format codes, e.g. %(name)s
paramstyle = 'pyformat'

# The shortcut methods are not part of DB API 2 and are not
# recommended by the DB SIG, but they can come in handy.
shortcutmethods = 1
114
115
116### Internal Type Handling
117
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        argspec = getargspec(func)
        return argspec.args
else:

    def get_args(func):
        """Return the list of parameter names of the given function."""
        params = signature(func).parameters
        return [name for name in params]
129
130try:
131    from datetime import timezone
132except ImportError:  # Python < 3.2
133
134    class timezone(tzinfo):
135        """Simple timezone implementation."""
136
137        def __init__(self, offset, name=None):
138            self.offset = offset
139            if not name:
140                minutes = self.offset.days * 1440 + self.offset.seconds // 60
141                if minutes < 0:
142                    hours, minutes = divmod(-minutes, 60)
143                    hours = -hours
144                else:
145                    hours, minutes = divmod(minutes, 60)
146                name = 'UTC%+03d:%02d' % (hours, minutes)
147            self.name = name
148
149        def utcoffset(self, dt):
150            return self.offset
151
152        def tzname(self, dt):
153            return self.name
154
155        def dst(self, dt):
156            return None
157
158    timezone.utc = timezone(timedelta(0), 'UTC')
159
160    _has_timezone = False
161else:
162    _has_timezone = True
163
164# time zones used in Postgres timestamptz output
165_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
166    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
167    UCT='+0000', UTC='+0000', WET='+0000')
168
169
170def _timezone_as_offset(tz):
171    if tz.startswith(('+', '-')):
172        if len(tz) < 5:
173            return tz + '00'
174        return tz.replace(':', '')
175    return _timezones.get(tz, '+0000')
176
177
178def _get_timezone(tz):
179    tz = _timezone_as_offset(tz)
180    minutes = 60 * int(tz[1:3]) + int(tz[3:5])
181    if tz[0] == '-':
182        minutes = -minutes
183    return timezone(timedelta(minutes=minutes), tz)
184
185
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal values.

    If decimal_type is given, it replaces the module-level Decimal name
    and is registered as the global typecast for 'numeric'.  In any case,
    the current decimal type is returned.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    Decimal = decimal_type
    set_typecast('numeric', decimal_type)
    return Decimal
197
198
def cast_bool(value):
    """Cast a boolean in database format ('t' or 'f') to bool.

    An empty value yields None.
    """
    if not value:
        return None
    return value[0] in ('t', 'T')
203
204
def cast_money(value):
    """Cast a money value in database format to Decimal.

    An opening parenthesis marks a negative amount.  Every character
    except digits, the decimal point and the minus sign (e.g. currency
    symbols and group separators) is dropped.  An empty value yields None.
    """
    if not value:
        return None
    value = value.replace('(', '-')
    digits = ''.join(c for c in value if c.isdigit() or c in '.-')
    return Decimal(digits)
210
211
def cast_int2vector(value):
    """Cast an int2vector value (whitespace-separated integers) to a list."""
    return list(map(int, value.split()))
215
216
def cast_date(value, connection):
    """Cast a date value.

    Infinite dates and BC dates are mapped to date.min or date.max, as
    are dates longer than ten characters (years beyond 9999), which the
    date type cannot represent.
    """
    # The output format depends on the server setting DateStyle.  The ISO
    # default and the German setting are unambiguous, but the order of days
    # and months in the two other settings is not, so the date format must
    # be taken from the connection.
    if value in ('-infinity', 'infinity'):
        return date.min if value.startswith('-') else date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    text = parts[0]
    if len(text) > 10:
        return date.max
    return datetime.strptime(text, connection.date_format()).date()
235
236
def cast_time(value):
    """Cast a time value, with or without fractional seconds."""
    fmt = '%H:%M:%S'
    if len(value) > 8:
        # longer than HH:MM:SS, so fractional seconds follow
        fmt += '.%f'
    return datetime.strptime(value, fmt).time()
241
242
# splits a value into everything before the last '+' or '-' and the rest
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value.

    The value is split into the time and a trailing UTC offset starting
    with '+' or '-'; if there is no such offset, '+0000' is assumed.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
    if not _has_timezone:
        # fallback without datetime.timezone: attach the zone afterwards
        parsed = datetime.strptime(value, fmt).timetz()
        return parsed.replace(tzinfo=_get_timezone(tz))
    value += _timezone_as_offset(tz)
    return datetime.strptime(value, fmt + '%z').timetz()
260
261
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    The text format depends on the server setting DateStyle, so the date
    format is taken from connection.date_format().  Values that datetime
    cannot represent are clamped: '-infinity' and BC dates become
    datetime.min; 'infinity' and years with more than four digits become
    datetime.max.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    value = value.split()
    if value[-1] == 'BC':
        return datetime.min
    fmt = connection.date_format()
    if fmt.endswith('-%Y') and len(value) > 2:
        # Textual month style with a leading weekday, presumably DateStyle
        # 'Postgres', e.g. 'Wed Dec 17 07:37:16 1997'.  The weekday is
        # skipped; month and day, time and year are kept.
        value = value[1:5]
        if len(value[3]) > 4:
            return datetime.max  # year beyond 9999
        # the order of day and month follows the configured date format
        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
    else:
        # numeric date followed by the time, e.g. '1997-12-17 07:37:16'
        if len(value[0]) > 10:
            return datetime.max  # year beyond 9999
        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
    return datetime.strptime(' '.join(value), ' '.join(fmt))
283
284
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    Works like cast_timestamp, but the value also carries a time zone,
    either as a numeric UTC offset or as an abbreviation, and the result
    is an aware datetime.  Abbreviations unknown to _timezone_as_offset
    are treated as UTC.  '-infinity' and BC dates become datetime.min;
    'infinity' and years with more than four digits become datetime.max.
    These clamped results are naive datetimes.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    value = value.split()
    if value[-1] == 'BC':
        return datetime.min
    fmt = connection.date_format()
    if fmt.endswith('-%Y') and len(value) > 2:
        # Textual month style with a leading weekday, presumably DateStyle
        # 'Postgres', e.g. 'Wed Dec 17 07:37:16 1997 PST'; skip the weekday.
        value = value[1:]
        if len(value[3]) > 4:
            return datetime.max  # year beyond 9999
        # the order of day and month follows the configured date format
        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
        value, tz = value[:-1], value[-1]  # the zone is the last token
    else:
        if fmt.startswith('%Y-'):
            # ISO style: the offset is appended to the time,
            # e.g. '1997-12-17 07:37:16+01'
            tz = _re_timezone.match(value[1])
            if tz:
                value[1], tz = tz.groups()
            else:
                tz = '+0000'
        else:
            # other numeric date styles: the zone is a separate last token
            value, tz = value[:-1], value[-1]
        if len(value[0]) > 10:
            return datetime.max  # year beyond 9999
        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
    if _has_timezone:
        # let strptime parse the normalized offset via %z
        value.append(_timezone_as_offset(tz))
        fmt.append('%z')
        return datetime.strptime(' '.join(value), ' '.join(fmt))
    # fallback without datetime.timezone: attach the zone afterwards
    return datetime.strptime(' '.join(value), ' '.join(fmt)).replace(
        tzinfo=_get_timezone(tz))
320
321
# Regular expressions for the four possible settings of IntervalStyle.
# In each of them, the last group captures the digits after the decimal
# point of the seconds.

_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_usecs(fraction):
    """Convert the fractional digits of interval seconds to microseconds.

    PostgreSQL strips trailing zeros from fractional seconds in interval
    output (e.g. '00:00:01.5'), so the digits must be scaled by their
    position: '5' means 500000 microseconds, not 5.  Digits beyond the
    sixth decimal place are truncated.
    """
    return int(fraction[:6].ljust(6, '0'))


def _interval_ints(parts):
    """Convert matched interval parts (fraction of seconds last) to ints."""
    return [int(d) for d in parts[:-1]] + [_interval_usecs(parts[-1])]


def cast_interval(value):
    """Cast an interval value to a timedelta.

    Years are counted as 365 days and months as 30 days.
    Raises ValueError if the value cannot be parsed.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        secs_ago = m.pop(5) == '-'
        years, mons, days, hours, mins, secs, usecs = _interval_ints(m)
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            secs_ago = m.pop(5) == '-'
            m = _interval_ints(m)
            if ago:
                # 'ago' negates all parts of the interval
                m = [-d for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                hours_ago = m.pop(3) == '-'
                years, mons, days, hours, mins, secs, usecs = _interval_ints(m)
                if hours_ago:
                    # a leading sign on the time part applies to all of it
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = [d or '0' for d in m.groups()]
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    years, mons, days, hours, mins, secs, usecs = (
                        _interval_ints(m))
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
407
408
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
        'float4': float, 'float8': float,
        'numeric': Decimal, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': Uuid,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in local connection specific instances

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # create array cast
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                # store only if base type exists
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument."""
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type."""
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.

        The default casts must not be copied verbatim.  A default cast
        that needs a connection argument must be bound to this instance's
        connection, and an array cast must be derived from this instance's
        own base cast.  Otherwise, a connection-local instance would get
        the unbound global casts.  Removed entries are therefore recreated
        lazily by __missing__, which takes care of both.
        """
        if typ is None:
            self.clear()
            return
        if isinstance(typ, basestring):
            typ = [typ]
        defaults = self.defaults
        for t in typ:
            cast = defaults.get(t)
            if cast:
                self[t] = self._add_connection(cast)
            else:
                self.pop(t, None)
            # the array cast will be derived from the base cast on demand
            self.pop('_%s' % t, None)

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
530
531
# the global typecast dictionary, used as the default by all connections
_typecasts = Typecasts()


def get_typecast(typ):
    """Return the global typecast function for the given database type."""
    casts = _typecasts
    return casts.get(typ)


def set_typecast(typ, cast):
    """Register a global typecast function for the given database type(s).

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    casts = _typecasts
    casts.set(typ, cast)


def reset_typecast(typ=None):
    """Restore the default global typecasts for the given type(s).

    Without a type, all global typecasts are restored.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    casts = _typecasts
    casts.reset(typ)
558
559
class LocalTypecasts(Typecasts):
    """Map typecasts, including local composite types, to cast functions.

    The global typecast dictionary serves as the source of defaults.
    """

    defaults = _typecasts

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Array types (names starting with an underscore) get a cast derived
        from the cast of their base type; it is cached only if the base
        cast exists.  Other types use the global default cast, bound to
        this instance's connection if needed.  Failing that, a record cast
        is built from the fields of a composite type.  Returns None when
        no cast can be found.
        """
        if typ.startswith('_'):
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                self[typ] = cast
            return cast
        cast = self.defaults.get(typ)
        if cast:
            cast = self._add_connection(cast)
        else:
            fields = self.get_fields(typ)
            if not fields:
                return cast
            field_casts = [self[field.type] for field in fields]
            field_names = [field.name for field in fields]
            cast = self.create_record_cast(typ, field_names, field_casts)
        self[typ] = cast
        return cast

    def get_fields(self, typ):
        """Return the fields for the given record type.

        This placeholder returns an empty list; it is meant to be replaced
        by a method looking up the fields via the connection's type cache.
        """
        return []
595
596
class TypeCode(str):
    """Type code used as the DB-API 2.0 type_code of a result column.

    A TypeCode is a string equal to the PostgreSQL type name, but carries
    additional information about the type as attributes.
    """

    @classmethod
    def create(cls, oid, name, len, type, category, delim, relid):
        """Return a new type code for a PostgreSQL data type."""
        type_code = cls(name)
        type_code.oid, type_code.len, type_code.type = oid, len, type
        type_code.category, type_code.delim, type_code.relid = (
            category, delim, relid)
        return type_code
615
FieldInfo = namedtuple('FieldInfo', 'name type')  # name and type of a field
617
618
class TypeCache(dict):
    """Cache for database types.

    Both type OIDs and type names are mapped to TypeCode strings that
    carry information about the associated database type.  Types not
    yet cached are looked up in the database on demand.
    """

    def __init__(self, cnx):
        """Initialize the type cache for the given connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()
        typecasts = LocalTypecasts()
        # let the local typecasts look up composite fields via this cache
        typecasts.get_fields = self.get_fields
        typecasts.connection = cnx
        self._typecasts = typecasts
        if cnx.server_version < 80400:
            # older remote databases (not officially supported):
            # select NULL in place of typcategory
            category = "null as typcategory"
        else:
            category = "typcategory"
        self._query_pg_type = (
            "SELECT oid, typname, typlen, typtype, " + category +
            ", typdelim, typrelid FROM pg_type WHERE oid=%s")

    def __missing__(self, key):
        """Look up a type that is not cached yet in the database.

        The key may be a type OID (int) or a type name.  Names that contain
        neither a dot nor a double quote are double-quoted before being
        cast to regtype.  The found type is cached under its OID and name.
        Raises KeyError if the type cannot be found.
        """
        if isinstance(key, int):
            oid = key
        else:
            if '.' not in key and '"' not in key:
                key = '"%s"' % (key,)
            oid = "'%s'::regtype" % (self._escape_string(key),)
        try:
            self._src.execute(self._query_pg_type % (oid,))
        except ProgrammingError:
            rows = None
        else:
            rows = self._src.fetch(1)
        if not rows:
            raise KeyError('Type %s could not be found' % (key,))
        row = rows[0]
        type_code = TypeCode.create(int(row[0]), row[1],
            int(row[2]), row[3], row[4], row[5], int(row[6]))
        self[type_code.oid] = self[str(type_code)] = type_code
        return type_code

    def get(self, key, default=None):
        """Return the type for the key, looking it up if it is not cached.

        Returns default if the type cannot be found.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def get_fields(self, typ):
        """Return the names and types of the fields of a composite type.

        Returns None if the type is unknown or not composite.
        """
        if not isinstance(typ, TypeCode):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None  # this type is not composite
        self._src.execute("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % (typ.relid,))
        lookup = self.get
        return [FieldInfo(name, lookup(int(oid)))
            for name, oid in self._src.fetch(-1)]

    def get_typecast(self, typ):
        """Return the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Register a typecast function for the given database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Restore the default typecast function for the given type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast a value from the database according to its type."""
        if value is None:
            return None  # NULL values need no typecast
        cast = self.get_typecast(typ)
        if cast and cast is not str:
            return cast(value)
        return value  # no typecast necessary
709
710
711class _quotedict(dict):
712    """Dictionary with auto quoting of its items.
713
714    The quote attribute must be set to the desired quote function.
715    """
716
717    def __getitem__(self, key):
718        return self.quote(super(_quotedict, self).__getitem__(key))
719
720
721### Error messages
722
def _db_error(msg, cls=DatabaseError):
    """Create an error of class cls (DatabaseError by default).

    The sqlstate attribute of the error is set to None.
    """
    err = cls(msg)
    err.sqlstate = None
    return err


def _op_error(msg):
    """Create an OperationalError whose sqlstate attribute is None."""
    return _db_error(msg, cls=OperationalError)
733
734
735### Cursor Object
736
737class Cursor(object):
738    """Cursor object."""
739
740    def __init__(self, dbcnx):
741        """Create a cursor object for the database connection."""
742        self.connection = self._dbcnx = dbcnx
743        self._cnx = dbcnx._cnx
744        self.type_cache = dbcnx.type_cache
745        self._src = self._cnx.source()
746        # the official attribute for describing the result columns
747        self._description = None
748        if self.row_factory is Cursor.row_factory:
749            # the row factory needs to be determined dynamically
750            self.row_factory = None
751        else:
752            self.build_row_factory = None
753        self.rowcount = -1
754        self.arraysize = 1
755        self.lastrowid = None
756
757    def __iter__(self):
758        """Make cursor compatible to the iteration protocol."""
759        return self
760
761    def __enter__(self):
762        """Enter the runtime context for the cursor object."""
763        return self
764
765    def __exit__(self, et, ev, tb):
766        """Exit the runtime context for the cursor object."""
767        self.close()
768
    def _quote(self, value):
        """Quote value depending on its type.

        Returns the SQL representation of a Python value for interpolation
        into a query string:
        - None becomes NULL.
        - Strings and Binary values become escaped, single-quoted literals.
        - Date/time, interval and uuid values become quoted literals with
          an explicit cast.
        - Finite floats, ints, longs, Decimals and Literals are returned
          unchanged (not as strings); the % operator in _quoteparams then
          formats them.
        - Lists become ARRAY constructors and tuples become ROW
          constructors.
        - Any other object must provide a __pg_repr__() method.

        Raises InterfaceError if the type of the value cannot be adapted.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, (Hstore, Json)):
            # hstore and json wrappers are quoted via their string form
            value = str(value)
        if isinstance(value, basestring):
            if isinstance(value, Binary):
                value = self._cnx.escape_bytea(value)
                if bytes is not str:  # Python >= 3.0
                    # make the escaped bytea value a text string
                    value = value.decode('ascii')
            else:
                # NOTE(review): on Python 3, plain (non-Binary) bytes also
                # end up here; if escape_string returns bytes for them, the
                # '%s' below would format their repr (b'...') - confirm.
                value = self._cnx.escape_string(value)
            return "'%s'" % (value,)
        if isinstance(value, float):
            # special float values must be passed as quoted literals
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        # NOTE(review): unlike floats, non-finite Decimal values
        # (NaN, Infinity) are returned unquoted here - confirm intended.
        if isinstance(value, (int, long, Decimal, Literal)):
            return value
        # datetime must be checked before date, since it is a subclass
        if isinstance(value, datetime):
            if value.tzinfo:
                return "'%s'::timestamptz" % (value,)
            return "'%s'::timestamp" % (value,)
        if isinstance(value, date):
            return "'%s'::date" % (value,)
        if isinstance(value, time):
            if value.tzinfo:
                return "'%s'::timetz" % (value,)
            return "'%s'::time" % value
        if isinstance(value, timedelta):
            return "'%s'::interval" % (value,)
        if isinstance(value, Uuid):
            return "'%s'::uuid" % (value,)
        if isinstance(value, list):
            # Quote value as an ARRAY constructor. This is better than using
            # an array literal because it carries the information that this is
            # an array and not a string.  One issue with this syntax is that
            # you need to add an explicit typecast when passing empty arrays.
            # The ARRAY keyword is actually only necessary at the top level.
            if not value:  # exception for empty array
                return "'{}'"
            q = self._quote
            try:
                return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),)
            except UnicodeEncodeError:  # Python 2 with non-ascii values
                return u'ARRAY[%s]' % (','.join(unicode(q(v)) for v in value),)
        if isinstance(value, tuple):
            # Quote as a ROW constructor.  This is better than using a record
            # literal because it carries the information that this is a record
            # and not a string.  We don't use the keyword ROW in order to make
            # this usable with the IN syntax as well.  It is only necessary
            # when the records has a single column which is not really useful.
            q = self._quote
            try:
                return '(%s)' % (','.join(str(q(v)) for v in value),)
            except UnicodeEncodeError:  # Python 2 with non-ascii values
                return u'(%s)' % (','.join(unicode(q(v)) for v in value),)
        # Other objects can adapt themselves via __pg_repr__().
        # NOTE(review): an AttributeError raised *inside* __pg_repr__ is
        # also reported as "Do not know how to adapt type".
        try:
            value = value.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'Do not know how to adapt type %s' % (type(value),))
        # a tuple or list returned by __pg_repr__ is quoted recursively
        if isinstance(value, (tuple, list)):
            value = self._quote(value)
        return value
837
838    def _quoteparams(self, string, parameters):
839        """Quote parameters.
840
841        This function works for both mappings and sequences.
842
843        The function should be used even when there are no parameters,
844        so that we have a consistent behavior regarding percent signs.
845        """
846        if not parameters:
847            try:
848                return string % ()  # unescape literal quotes if possible
849            except (TypeError, ValueError):
850                return string  # silently accept unescaped quotes
851        if isinstance(parameters, dict):
852            parameters = _quotedict(parameters)
853            parameters.quote = self._quote
854        else:
855            parameters = tuple(map(self._quote, parameters))
856        return string % parameters
857
858    def _make_description(self, info):
859        """Make the description tuple for the given field info."""
860        name, typ, size, mod = info[1:]
861        type_code = self.type_cache[typ]
862        if mod > 0:
863            mod -= 4
864        if type_code == 'numeric':
865            precision, scale = mod >> 16, mod & 0xffff
866            size = precision
867        else:
868            if not size:
869                size = type_code.size
870            if size == -1:
871                size = mod
872            precision = scale = None
873        return CursorDescription(name, type_code,
874            None, size, precision, scale, None)
875
876    @property
877    def description(self):
878        """Read-only attribute describing the result columns."""
879        descr = self._description
880        if self._description is True:
881            make = self._make_description
882            descr = [make(info) for info in self._src.listinfo()]
883            self._description = descr
884        return descr
885
886    @property
887    def colnames(self):
888        """Unofficial convenience method for getting the column names."""
889        return [d[0] for d in self.description]
890
891    @property
892    def coltypes(self):
893        """Unofficial convenience method for getting the column types."""
894        return [d[1] for d in self.description]
895
896    def close(self):
897        """Close the cursor object."""
898        self._src.close()
899
900    def execute(self, operation, parameters=None):
901        """Prepare and execute a database operation (query or command)."""
902        # The parameters may also be specified as list of tuples to e.g.
903        # insert multiple rows in a single operation, but this kind of
904        # usage is deprecated.  We make several plausibility checks because
905        # tuples can also be passed with the meaning of ROW constructors.
906        if (parameters and isinstance(parameters, list)
907                and len(parameters) > 1
908                and all(isinstance(p, tuple) for p in parameters)
909                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
910            return self.executemany(operation, parameters)
911        else:
912            # not a list of tuples
913            return self.executemany(operation, [parameters])
914
    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        The operation is executed once for every parameter set in
        seq_of_parameters, each time with the parameters quoted and
        substituted by _quoteparams(). If no transaction is open on the
        connection, one is started with BEGIN first and marked as open,
        so that it can be ended later by the connection's commit() or
        rollback().

        Returns the cursor itself, so that calls can be chained, or None
        when seq_of_parameters is empty (then nothing is executed).

        Database errors are re-raised as they are. Other errors of this
        module are wrapped via _db_error as InterfaceError, and any other
        exception via _op_error (presumably as OperationalError).
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        # reset the result state left over from any previous operation
        self._description = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        # sql always holds the statement currently being executed,
        # so that the error messages below can show it
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                # implicitly start a transaction on the connection
                try:
                    self._src.execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("Can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                # NOTE(review): rows is summed up as the DML row count, so
                # presumably execute() returns the number of affected rows
                # for DML statements and a false value otherwise - confirm
                if rows:
                    rowcount += rows
                else:
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "Error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("Internal error in '%s': %s" % (sql, err))
        # then initialize result row count and description
        # (only the result of the last executed statement is inspected)
        if self._src.resulttype == RESULT_DQL:
            self._description = True  # fetch on demand
            self.rowcount = self._src.ntuples
            self.lastrowid = None
            # build a row factory for the new columns, unless this has
            # been disabled by setting build_row_factory to a false value
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self
962
963    def fetchone(self):
964        """Fetch the next row of a query result set."""
965        res = self.fetchmany(1, False)
966        try:
967            return res[0]
968        except IndexError:
969            return None
970
971    def fetchall(self):
972        """Fetch all (remaining) rows of a query result."""
973        return self.fetchmany(-1, False)
974
975    def fetchmany(self, size=None, keep=False):
976        """Fetch the next set of rows of a query result.
977
978        The number of rows to fetch per call is specified by the
979        size parameter. If it is not given, the cursor's arraysize
980        determines the number of rows to be fetched. If you set
981        the keep parameter to true, this is kept as new arraysize.
982        """
983        if size is None:
984            size = self.arraysize
985        if keep:
986            self.arraysize = size
987        try:
988            result = self._src.fetch(size)
989        except DatabaseError:
990            raise
991        except Error as err:
992            raise _db_error(str(err))
993        typecast = self.type_cache.typecast
994        return [self.row_factory([typecast(value, typ)
995            for typ, value in zip(self.coltypes, row)]) for row in result]
996
997    def callproc(self, procname, parameters=None):
998        """Call a stored database procedure with the given name.
999
1000        The sequence of parameters must contain one entry for each input
1001        argument that the procedure expects. The result of the call is the
1002        same as this input sequence; replacement of output and input/output
1003        parameters in the return value is currently not supported.
1004
1005        The procedure may also provide a result set as output. These can be
1006        requested through the standard fetch methods of the cursor.
1007        """
1008        n = parameters and len(parameters) or 0
1009        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
1010        self.execute(query, parameters)
1011        return parameters
1012
1013    def copy_from(self, stream, table,
1014            format=None, sep=None, null=None, size=None, columns=None):
1015        """Copy data from an input stream to the specified table.
1016
1017        The input stream can be a file-like object with a read() method or
1018        it can also be an iterable returning a row or multiple rows of input
1019        on each iteration.
1020
1021        The format must be text, csv or binary. The sep option sets the
1022        column separator (delimiter) used in the non binary formats.
1023        The null option sets the textual representation of NULL in the input.
1024
1025        The size option sets the size of the buffer used when reading data
1026        from file-like objects.
1027
1028        The copy operation can be restricted to a subset of columns. If no
1029        columns are specified, all of them will be copied.
1030        """
1031        binary_format = format == 'binary'
1032        try:
1033            read = stream.read
1034        except AttributeError:
1035            if size:
1036                raise ValueError("Size must only be set for file-like objects")
1037            if binary_format:
1038                input_type = bytes
1039                type_name = 'byte strings'
1040            else:
1041                input_type = basestring
1042                type_name = 'strings'
1043
1044            if isinstance(stream, basestring):
1045                if not isinstance(stream, input_type):
1046                    raise ValueError("The input must be %s" % (type_name,))
1047                if not binary_format:
1048                    if isinstance(stream, str):
1049                        if not stream.endswith('\n'):
1050                            stream += '\n'
1051                    else:
1052                        if not stream.endswith(b'\n'):
1053                            stream += b'\n'
1054
1055                def chunks():
1056                    yield stream
1057
1058            elif isinstance(stream, Iterable):
1059
1060                def chunks():
1061                    for chunk in stream:
1062                        if not isinstance(chunk, input_type):
1063                            raise ValueError(
1064                                "Input stream must consist of %s"
1065                                % (type_name,))
1066                        if isinstance(chunk, str):
1067                            if not chunk.endswith('\n'):
1068                                chunk += '\n'
1069                        else:
1070                            if not chunk.endswith(b'\n'):
1071                                chunk += b'\n'
1072                        yield chunk
1073
1074            else:
1075                raise TypeError("Need an input stream to copy from")
1076        else:
1077            if size is None:
1078                size = 8192
1079            elif not isinstance(size, int):
1080                raise TypeError("The size option must be an integer")
1081            if size > 0:
1082
1083                def chunks():
1084                    while True:
1085                        buffer = read(size)
1086                        yield buffer
1087                        if not buffer or len(buffer) < size:
1088                            break
1089
1090            else:
1091
1092                def chunks():
1093                    yield read()
1094
1095        if not table or not isinstance(table, basestring):
1096            raise TypeError("Need a table to copy to")
1097        if table.lower().startswith('select'):
1098                raise ValueError("Must specify a table, not a query")
1099        else:
1100            table = '"%s"' % (table,)
1101        operation = ['copy %s' % (table,)]
1102        options = []
1103        params = []
1104        if format is not None:
1105            if not isinstance(format, basestring):
1106                raise TypeError("The frmat option must be be a string")
1107            if format not in ('text', 'csv', 'binary'):
1108                raise ValueError("Invalid format")
1109            options.append('format %s' % (format,))
1110        if sep is not None:
1111            if not isinstance(sep, basestring):
1112                raise TypeError("The sep option must be a string")
1113            if format == 'binary':
1114                raise ValueError(
1115                    "The sep option is not allowed with binary format")
1116            if len(sep) != 1:
1117                raise ValueError(
1118                    "The sep option must be a single one-byte character")
1119            options.append('delimiter %s')
1120            params.append(sep)
1121        if null is not None:
1122            if not isinstance(null, basestring):
1123                raise TypeError("The null option must be a string")
1124            options.append('null %s')
1125            params.append(null)
1126        if columns:
1127            if not isinstance(columns, basestring):
1128                columns = ','.join('"%s"' % (col,) for col in columns)
1129            operation.append('(%s)' % (columns,))
1130        operation.append("from stdin")
1131        if options:
1132            operation.append('(%s)' % (','.join(options),))
1133        operation = ' '.join(operation)
1134
1135        putdata = self._src.putdata
1136        self.execute(operation, params)
1137
1138        try:
1139            for chunk in chunks():
1140                putdata(chunk)
1141        except BaseException as error:
1142            self.rowcount = -1
1143            # the following call will re-raise the error
1144            putdata(error)
1145        else:
1146            self.rowcount = putdata(None)
1147
1148        # return the cursor object, so you can chain operations
1149        return self
1150
    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.

        The COPY statement is executed immediately, even when a generator
        is returned; rowcount is updated while the rows are retrieved.
        Returns the cursor object when an output stream was given.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("Need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            # a query must be put in parentheses and may not be combined
            # with a separate column list
            if columns:
                raise ValueError("Columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("The format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("Invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("The sep option must be a string")
            if binary_format:
                raise ValueError(
                    "The sep option is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError(
                    "The sep option must be a single one-byte character")
            # the separator and null string are passed as quoted parameters
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("The null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            # by default decode text output only if str is the unicode type
            if format == 'binary':
                decode = False
            else:
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("The decode option must be a boolean")
            if decode and binary_format:
                raise ValueError(
                    "The decode option is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % (','.join(options),))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            # yield the rows delivered by getdata() until it returns an
            # int, which is taken as the final row count of the copy
            self.rowcount = 0
            while True:
                row = getdata(decode)
                if isinstance(row, int):
                    if self.rowcount != row:
                        self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no output stream given, return the generator
            return copy()

        # write the rows to the file-like output stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self
1254
1255    def __next__(self):
1256        """Return the next row (support for the iteration protocol)."""
1257        res = self.fetchone()
1258        if res is None:
1259            raise StopIteration
1260        return res
1261
1262    # Note that since Python 2.6 the iterator protocol uses __next()__
1263    # instead of next(), we keep it only for backward compatibility of pgdb.
1264    next = __next__
1265
1266    @staticmethod
1267    def nextset():
1268        """Not supported."""
1269        raise NotSupportedError("The nextset() method is not supported")
1270
1271    @staticmethod
1272    def setinputsizes(sizes):
1273        """Not supported."""
1274        pass  # unsupported, but silently passed
1275
1276    @staticmethod
1277    def setoutputsize(size, column=0):
1278        """Not supported."""
1279        pass  # unsupported, but silently passed
1280
1281    @staticmethod
1282    def row_factory(row):
1283        """Process rows before they are returned.
1284
1285        You can overwrite this statically with a custom row factory, or
1286        you can build a row factory dynamically with build_row_factory().
1287
1288        For example, you can create a Cursor class that returns rows as
1289        Python dictionaries like this:
1290
1291            class DictCursor(pgdb.Cursor):
1292
1293                def row_factory(self, row):
1294                    return {desc[0]: value
1295                        for desc, value in zip(self.description, row)}
1296
1297            cur = DictCursor(con)  # get one DictCursor instance or
1298            con.cursor_type = DictCursor  # always use DictCursor instances
1299        """
1300        raise NotImplementedError
1301
1302    def build_row_factory(self):
1303        """Build a row factory based on the current description.
1304
1305        This implementation builds a row factory for creating named tuples.
1306        You can overwrite this method if you want to dynamically create
1307        different row factories whenever the column description changes.
1308        """
1309        colnames = self.colnames
1310        if colnames:
1311            try:
1312                try:
1313                    return namedtuple('Row', colnames, rename=True)._make
1314                except TypeError:  # Python 2.6 and 3.0 do not support rename
1315                    colnames = [v if v.isalnum() else 'column_%d' % (n,)
1316                             for n, v in enumerate(colnames)]
1317                    return namedtuple('Row', colnames)._make
1318            except ValueError:  # there is still a problem with the field names
1319                colnames = ['column_%d' % (n,) for n in range(len(colnames))]
1320                return namedtuple('Row', colnames)._make
1321
1322
# Named tuple for the seven items of a DB-API 2 cursor description entry.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
1326
1327
1328### Connection Objects
1329
class Connection(object):
    """Connection object.

    Wraps a low-level database connection and keeps track of whether a
    transaction has been opened (cursors start one implicitly), so that
    it can be committed or rolled back.
    """

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object wrapping cnx."""
        self._cnx = cnx  # underlying connection, None after close()
        self._tnx = False  # whether a transaction is currently open
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # class instantiated by cursor()
        # check that the connection is usable before handing it out
        try:
            cnx.source()
        except Exception:
            raise _op_error("Invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        The open transaction is committed if the block was left normally
        and rolled back if an exception occurred. The connection itself
        stays open.
        """
        if any(item is not None for item in (et, ev, tb)):
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection, rolling back an open transaction first.

        Raises an operational error if the connection is already closed.
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # the connection will be closed anyway
        self._cnx.close()
        self._cnx = None

    @property
    def closed(self):
        """True if the connection has been closed or is broken.

        A connection counts as broken when its status is not 1.
        """
        try:
            if not self._cnx:
                return True
            return self._cnx.status != 1
        except TypeError:
            return True

    def _end_transaction(self, command, failure):
        """Send command to end the open transaction, if there is one.

        Raises an operational error with the failure message if sending
        the command fails for other reasons than a database error.
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return
        self._tnx = False
        try:
            self._cnx.source().execute(command)
        except DatabaseError:
            raise
        except Exception:
            raise _op_error(failure)

    def commit(self):
        """Commit any pending transaction to the database."""
        self._end_transaction("COMMIT", "Can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        self._end_transaction("ROLLBACK", "Can't rollback")

    def cursor(self):
        """Return a new cursor of type cursor_type using the connection."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("Invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cur = self.cursor()
            cur.execute(operation, params)
            return cur

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cur = self.cursor()
            cur.executemany(operation, param_seq)
            return cur
1445
1446
1447### Module Interface
1448
# Keep a reference to the connect function that is in scope at this point
# (presumably the low-level one imported from the C extension - confirm),
# because the DB-API connect() defined below shadows that name and uses
# _connect() to open the actual connection.
_connect = connect
1450
def connect(dsn=None,
        user=None, password=None,
        host=None, database=None, **kwargs):
    """Connect to a database.

    The dsn has the form 'host:database:user:password:opt', where all
    parts are optional. The user, password and database keyword arguments
    override the corresponding parts of the dsn; host overrides the host
    and may be given as 'host:port' to also set the port.

    Any additional keyword arguments are passed as keyword=value pairs of
    a connection info string, which is used in place of the database name.
    Values that are empty or contain whitespace, single quotes or
    backslashes are put in single quotes, with backslashes and single
    quotes escaped by a backslash.

    Returns a Connection object.
    """
    # first get params from DSN
    dbport = -1
    dbhost = ""
    dbname = ""
    dbuser = ""
    dbpasswd = ""
    dbopt = ""
    try:
        params = dsn.split(":")
        dbhost = params[0]
        dbname = params[1]
        dbuser = params[2]
        dbpasswd = params[3]
        dbopt = params[4]
    except (AttributeError, IndexError, TypeError):
        pass  # missing parts keep their defaults

    # override if necessary
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbname = database
    if host is not None:
        try:
            params = host.split(":")
            dbhost = params[0]
            dbport = int(params[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # no (valid) port given

    # empty host is localhost
    if dbhost == "":
        dbhost = None
    if dbuser == "":
        dbuser = None

    # pass keyword arguments as connection info string
    if kwargs:
        kwargs = list(kwargs.items())
        if '=' in dbname:
            # the database name is already a connection info string
            dbname = [dbname]
        else:
            kwargs.insert(0, ('dbname', dbname))
            dbname = []
        for kw, value in kwargs:
            value = str(value)
            if not value or any(c.isspace() or c in "'\\" for c in value):
                # Escape backslashes before quotes, so that the backslashes
                # inserted for the quotes are not doubled again.
                value = "'%s'" % (value.replace(
                    '\\', '\\\\').replace("'", "\\'"),)
            dbname.append('%s=%s' % (kw, value))
        dbname = ' '.join(dbname)

    # open the connection
    cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
1511    return Connection(cnx)
1512
1513
1514### Types Handling
1515
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozen set of type names. It compares equal to a string
    that is one of its names, or the array form of one of its names (the
    name prefixed with an underscore). Other comparisons use the normal
    frozenset semantics.
    """

    # NOTE(review): since __eq__ is defined here, Python 3 makes instances
    # of this class unhashable (Python 2 keeps the frozenset hash).

    def __new__(cls, values):
        # a string is taken as a whitespace separated list of type names
        names = values.split() if isinstance(values, basestring) else values
        return super(Type, cls).__new__(cls, names)

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return super(Type, self).__eq__(other)
        # array type names are the base type names prefixed with '_'
        name = other[1:] if other.startswith('_') else other
        return name in self

    def __ne__(self, other):
        if not isinstance(other, basestring):
            return super(Type, self).__ne__(other)
        name = other[1:] if other.startswith('_') else other
        return name not in self
1543
1544
class ArrayType:
    """Type class for PostgreSQL array types.

    Compares equal to every type name starting with an underscore (the
    names of array types) and to other ArrayType instances.
    """

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return isinstance(other, ArrayType)
        return other.startswith('_')

    def __ne__(self, other):
        if not isinstance(other, basestring):
            return not isinstance(other, ArrayType)
        return not other.startswith('_')
1559
1560
class RecordType:
    """Type class for PostgreSQL record types.

    Compares equal to type codes of composite types (those with type 'c'),
    to the type name 'record' and to other RecordType instances.
    """

    def __eq__(self, other):
        # TypeCode must be checked before plain strings
        if isinstance(other, TypeCode):
            return other.type == 'c'
        if isinstance(other, basestring):
            return other == 'record'
        return isinstance(other, RecordType)

    def __ne__(self, other):
        if isinstance(other, TypeCode):
            return other.type != 'c'
        if isinstance(other, basestring):
            return other != 'record'
        return not isinstance(other, RecordType)
1579
1580
# Mandatory type objects defined by DB-API 2 specs.
# Each Type object compares equal to any of the PostgreSQL type names
# listed for it, and also to the names of the corresponding array types,
# which are the same names prefixed with an underscore.

STRING = Type('char bpchar name text varchar')
BINARY = Type('bytea')
NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
DATETIME = Type('date time timetz timestamp timestamptz interval'
    ' abstime reltime')  # these are very old (obsolete) types
ROWID = Type('oid')


# Additional type objects (more specific):

BOOL = Type('bool')
SMALLINT = Type('int2')
INTEGER = Type('int2 int4 int8 serial')
LONG = Type('int8')
FLOAT = Type('float4 float8')
NUMERIC = Type('numeric')
MONEY = Type('money')
DATE = Type('date')
TIME = Type('time timetz')
TIMESTAMP = Type('timestamp timestamptz')
INTERVAL = Type('interval')
UUID = Type('uuid')
HSTORE = Type('hstore')
JSON = Type('json jsonb')

# Type object for arrays, equal to every type name starting with an
# underscore (note that the Type objects above also equate to the array
# names of their base types):

ARRAY = ArrayType()

# Type object for records (encompassing all composite types):

RECORD = RecordType()
1615
1616
1617# Mandatory type helpers defined by DB-API 2 specs:
1618
def Date(year, month, day):
    """Construct an object holding a date value (a datetime.date)."""
    return date(year=year, month=month, day=day)
1622
1623
def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Construct an object holding a time value (a datetime.time)."""
    return time(hour, minute, second, microsecond, tzinfo=tzinfo)
1627
1628
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Construct an object holding a time stamp value (a datetime)."""
    return datetime(year, month, day, hour, minute, second, microsecond,
        tzinfo=tzinfo)
1633
1634
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks (seconds since the epoch) are interpreted in local time.
    """
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
1638
1639
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    The ticks are interpreted in local time; fractions of a second are
    not preserved.
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
1643
1644
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    The ticks are interpreted in local time; fractions of a second are
    not preserved.
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
1648
1649
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is a plain subclass of bytes that adds no behavior of its own.
    """
1652
1653
1654# Additional type helpers for PyGreSQL:
1655
def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
    """Construct an object holding a time interval value.

    The interval is returned as a datetime.timedelta built from the
    given days, hours, minutes, seconds and microseconds.
    """
    return timedelta(days, hours=hours, minutes=minutes, seconds=seconds,
        microseconds=microseconds)
1660
1661
# Re-export the Uuid name that is already in scope at this point as part of
# the module interface (presumably bound to uuid.UUID by an import - confirm).
Uuid = Uuid  # Construct an object holding a UUID value
1663
1664
class Hstore(dict):
    """Wrapper class for marking hstore values.

    The string representation is the hstore input syntax: a comma
    separated list of key=>value pairs, where None values become NULL.
    """

    # Keys and values must be put in double quotes if they are the word
    # NULL (in any case) or contain whitespace of any kind (including tabs
    # and newlines, not only blanks), commas, equal or greater-than signs.
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>]')
    # double quotes and backslashes are escaped with a backslash
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s):
        """Quote a single key or value for the hstore input syntax."""
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        quote = cls._re_quote.search(s)
        s = cls._re_escape.sub(r'\\\1', s)
        if quote:
            s = '"%s"' % (s,)
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
1686
1687
class Json:
    """Construct a wrapper for holding an object serializable to JSON.

    The string representation is the JSON text produced by the encode
    function (jsonencode by default). Strings are taken as JSON text
    already and are returned unchanged.
    """

    def __init__(self, obj, encode=None):
        self.obj = obj
        if not encode:
            encode = jsonencode
        self.encode = encode

    def __str__(self):
        value = self.obj
        return value if isinstance(value, basestring) else self.encode(value)
1700
1701
class Literal:
    """Construct a wrapper for holding a literal SQL string.

    The wrapped SQL is returned unchanged both as string representation
    and by __pg_repr__(), so it is inserted into queries without quoting.
    """

    def __init__(self, sql):
        self.sql = sql  # the SQL text to be used verbatim

    def __str__(self):
        sql = self.sql
        return sql

    # the same method serves as the representation used for queries
    __pg_repr__ = __str__
1712
1713# If run as script, print some information:
1714
if __name__ == '__main__':
    # show the module version followed by the module documentation
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.