source: trunk/pgdb.py @ 989

Last change on this file since 989 was 989, checked in by cito, 3 months ago

Minor whitespace fixes and IDE hints

  • Property svn:keywords set to Id
File size: 60.2 KB
Line 
1#!/usr/bin/python
2#
3# $Id: pgdb.py 989 2019-04-24 15:49:20Z cito $
4#
5# PyGreSQL - a Python interface for the PostgreSQL database.
6#
7# This file contains the DB-API 2 compatible pgdb module.
8#
9# Copyright (c) 2019 by the PyGreSQL Development Team
10#
11# Please see the LICENSE.TXT file for specific restrictions.
12
13"""pgdb - DB-API 2.0 compliant module for PyGreSQL.
14
15(c) 1999, Pascal Andre <andre@via.ecp.fr>.
16See package documentation for further information on copyright.
17
18Inline documentation is sparse.
19See DB-API 2.0 specification for usage information:
20http://www.python.org/peps/pep-0249.html
21
22Basic usage:
23
24    pgdb.connect(connect_string) # open a connection
25    # connect_string = 'host:database:user:password:opt'
26    # All parts are optional. You may also pass host through
27    # password as keyword arguments. To pass a port,
28    # pass it in the host keyword parameter:
29    connection = pgdb.connect(host='localhost:5432')
30
31    cursor = connection.cursor() # open a cursor
32
33    cursor.execute(query[, params])
34    # Execute a query, binding params (a dictionary) if they are
35    # passed. The binding syntax is the same as the % operator
36    # for dictionaries, and no quoting is done.
37
38    cursor.executemany(query, list of params)
39    # Execute a query many times, binding each param dictionary
40    # from the list.
41
42    cursor.fetchone() # fetch one row, [value, value, ...]
43
44    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
45
46    cursor.fetchmany([size])
47    # returns size or cursor.arraysize number of rows,
48    # [[value, value, ...], ...] from result set.
49    # Default cursor.arraysize is 1.
50
51    cursor.description # returns information about the columns
52    #   [(column_name, type_name, display_size,
53    #           internal_size, precision, scale, null_ok), ...]
54    # Note that display_size, precision, scale and null_ok
55    # are not implemented.
56
57    cursor.rowcount # number of rows available in the result set
58    # Available after a call to execute.
59
60    connection.commit() # commit transaction
61
62    connection.rollback() # or rollback transaction
63
64    cursor.close() # close the cursor
65
66    connection.close() # close the connection
67"""
68
69from __future__ import print_function, division
70
71from _pg import *
72
73__version__ = version
74
75from datetime import date, time, datetime, timedelta, tzinfo
76from time import localtime
77from decimal import Decimal
78from uuid import UUID as Uuid
79from math import isnan, isinf
80try:
81    from collections.abc import Iterable
82except ImportError:  # Python < 3.3
83    from collections import Iterable
84from collections import namedtuple
85from keyword import iskeyword
86from functools import partial
87from re import compile as regex
88from json import loads as jsondecode, dumps as jsonencode
89
# Compatibility shims: each block probes a builtin name that may be missing
# and, when the probe raises NameError, binds a replacement under that name
# so that the rest of the module can use it unconditionally.

try:  # noinspection PyUnresolvedReferences
    long
except NameError:  # Python >= 3.0
    long = int

try:  # noinspection PyUnresolvedReferences
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:  # noinspection PyUnresolvedReferences
    basestring
except NameError:  # Python >= 3.0
    # a tuple, so it is only usable as the second argument of isinstance()
    basestring = (str, bytes)
104
# Use the standard library lru_cache where it exists; otherwise define a
# reduced replacement that only supports functions taking one hashable
# positional argument.
try:
    from functools import lru_cache
except ImportError:  # Python < 3.2
    from functools import update_wrapper
    try:
        from _thread import RLock
    except ImportError:
        class RLock:  # for builds without threads
            # dummy context manager that does no locking at all
            def __enter__(self): pass

            def __exit__(self, exctype, excinst, exctb): pass

    def lru_cache(maxsize=128):
        """Simplified functools.lru_cache decorator for one argument.

        With maxsize 0 nothing is cached, with maxsize None the cache is
        a plain dictionary that is never trimmed, and otherwise the least
        recently used entry is replaced once maxsize entries are stored.
        The undecorated function is available as the __wrapped__ attribute
        of the returned wrapper.
        """

        def decorator(function):
            sentinel = object()  # marks a cache miss in the unbounded variant
            cache = {}
            get = cache.get
            lock = RLock()
            # The bounded variant keeps its entries in a circular doubly
            # linked list.  Every link is a list [prev, next, arg, result]
            # and root is the link that separates the newest entry
            # (root[0]) from the oldest entry (root[1]).
            root = []
            # [current root, cache is full]; kept in a list so that the
            # nested wrapper can replace both values by item assignment
            root_full = [root, False]
            root[:] = [root, root, None, None]

            if maxsize == 0:

                def wrapper(arg):
                    # no caching, every call goes to the function
                    res = function(arg)
                    return res

            elif maxsize is None:

                def wrapper(arg):
                    # unbounded caching without any eviction
                    res = get(arg, sentinel)
                    if res is not sentinel:
                        return res
                    res = function(arg)
                    cache[arg] = res
                    return res

            else:

                def wrapper(arg):
                    with lock:
                        link = get(arg)
                        if link is not None:
                            # cache hit: unlink the entry and insert it
                            # again just before root, as the newest entry
                            root = root_full[0]
                            prev, next, _arg, res = link
                            prev[1] = next
                            next[0] = prev
                            last = root[0]
                            last[1] = root[0] = link
                            link[0] = last
                            link[1] = root
                            return res
                    # cache miss: the function is called outside the lock
                    res = function(arg)
                    with lock:
                        root, full = root_full
                        if arg in cache:
                            # the entry has been added while the lock was
                            # released, so there is nothing left to do
                            pass
                        elif full:
                            # store the new entry in the old root link and
                            # turn the oldest link into the new, empty root
                            oldroot = root
                            oldroot[2] = arg
                            oldroot[3] = res
                            root = root_full[0] = oldroot[1]
                            oldarg = root[2]
                            oldres = root[3]  # keep reference
                            root[2] = root[3] = None
                            del cache[oldarg]
                            cache[arg] = oldroot
                        else:
                            # append a new link just before root
                            last = root[0]
                            link = [last, root, arg, res]
                            last[1] = root[0] = cache[arg] = link
                            if len(cache) >= maxsize:
                                root_full[1] = True
                    return res

            wrapper.__wrapped__ = function
            return update_wrapper(wrapper, function)

        return decorator
187
188
189### Module Constants
190
# compliant with DB API 2.0
apilevel = '2.0'

# module may be shared, but not connections
threadsafety = 1

# this module uses extended python format codes
paramstyle = 'pyformat'

# shortcut methods have been excluded from DB API 2 and
# are not recommended by the DB SIG, but they can be handy
shortcutmethods = 1
203
204
205### Internal Type Handling
206
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the list of parameter names of the given function."""
        parameters = signature(func).parameters
        return [name for name in parameters]
218
try:
    from datetime import timezone
except ImportError:  # Python < 3.2

    class timezone(tzinfo):
        """Simple timezone implementation.

        Represents a fixed offset from UTC that is passed as a timedelta.
        When no name is passed, a name of the form UTC+HH:MM or UTC-HH:MM
        is derived from the offset.
        """

        def __init__(self, offset, name=None):
            self.offset = offset
            if not name:
                minutes = self.offset.days * 1440 + self.offset.seconds // 60
                # The sign must be formatted separately from the hours,
                # because otherwise it gets lost for negative offsets
                # of less than one hour (where the hours are zero).
                sign = '-' if minutes < 0 else '+'
                hours, minutes = divmod(abs(minutes), 60)
                name = 'UTC%s%02d:%02d' % (sign, hours, minutes)
            self.name = name

        def utcoffset(self, dt):
            """Return the fixed offset, whatever the value of dt is."""
            return self.offset

        def tzname(self, dt):
            """Return the name of the time zone, whatever the value of dt is."""
            return self.name

        def dst(self, dt):
            """Return None since no daylight saving time info is available."""
            return None

    timezone.utc = timezone(timedelta(0), 'UTC')

    # the replacement class is in use (strptime cannot create it using %z)
    _has_timezone = False
else:
    _has_timezone = True
252
253# time zones used in Postgres timestamptz output
254_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
255    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
256    UCT='+0000', UTC='+0000', WET='+0000')
257
258
259def _timezone_as_offset(tz):
260    if tz.startswith(('+', '-')):
261        if len(tz) < 5:
262            return tz + '00'
263        return tz.replace(':', '')
264    return _timezones.get(tz, '+0000')
265
266
def _get_timezone(tz):
    """Create a timezone instance for the given time zone string.

    The string is first normalized with _timezone_as_offset(), and the
    normalized offset string is also used as the name of the time zone.
    """
    offset = _timezone_as_offset(tz)
    hours = int(offset[1:3])
    mins = int(offset[3:5])
    minutes = 60 * hours + mins
    if offset[0] == '-':
        minutes = -minutes
    return timezone(timedelta(minutes=minutes), offset)
273
274
def decimal_type(decimal_type=None):
    """Get or set global type to be used for decimal values.

    When a type is passed, it replaces the module level Decimal and is
    registered as the global typecast for 'numeric'.  The type that is
    in effect after the call is returned in any case.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    Decimal = decimal_type
    set_typecast('numeric', decimal_type)
    return Decimal
286
287
def cast_bool(value):
    """Cast boolean value in database format to bool.

    Returns None for an empty value, and otherwise whether the value
    starts with the letter 't' (in lower or upper case).
    """
    if not value:
        return None
    first = value[0]
    return first in ('t', 'T')
292
293
def cast_money(value):
    """Cast money value in database format to Decimal.

    An opening parenthesis is taken as a minus sign, and afterwards all
    characters other than digits, dots and minus signs are dropped.
    Returns None for an empty value.
    """
    if not value:
        return None
    value = value.replace('(', '-')
    kept = [c for c in value if c.isdigit() or c in '.-']
    return Decimal(''.join(kept))
299
300
def cast_int2vector(value):
    """Cast an int2vector value to a list of integers."""
    return list(map(int, value.split()))
304
305
def cast_date(value, connection):
    """Cast a date value.

    Infinite values are mapped to date.min and date.max, and so are dates
    marked as BC and dates that are too long to be covered by a Python date.
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value == '-infinity':
        return date.min
    if value == 'infinity':
        return date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    text = parts[0]
    if len(text) > 10:
        return date.max
    parsed = datetime.strptime(text, connection.date_format())
    return parsed.date()
324
325
def cast_time(value):
    """Cast a time value.

    Values with more than eight characters are parsed with a fraction.
    """
    if len(value) > 8:
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    parsed = datetime.strptime(value, fmt)
    return parsed.time()
330
331
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value.

    The time zone is split off at the last sign character in the value.
    If there is no such sign, an offset of '+0000' is used.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    if len(value) > 8:
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    if not _has_timezone:
        # parse the time as naive value and attach the time zone afterwards
        naive = datetime.strptime(value, fmt).timetz()
        return naive.replace(tzinfo=_get_timezone(tz))
    value += _timezone_as_offset(tz)
    fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
349
350
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    Infinite values are mapped to datetime.min and datetime.max, and so are
    values marked as BC and values with a year that has too many digits.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # the first word is skipped; then follow two words for
        # day and month, the time and the year
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        if date_fmt.startswith('%d'):
            day_month_fmt = '%d %b'
        else:
            day_month_fmt = '%b %d'
        if len(parts[2]) > 8:
            time_fmt = '%H:%M:%S.%f'
        else:
            time_fmt = '%H:%M:%S'
        formats = [day_month_fmt, time_fmt, '%Y']
    else:
        if len(parts[0]) > 10:
            return datetime.max
        if len(parts[1]) > 8:
            time_fmt = '%H:%M:%S.%f'
        else:
            time_fmt = '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(formats))
372
373
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    Infinite values are mapped to datetime.min and datetime.max, and so are
    values marked as BC and values with a year that has too many digits.
    Note that these special values are returned without a time zone.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # the first word is skipped; then follow two words for
        # day and month, the time, the year and the time zone
        parts = parts[1:]
        if len(parts[3]) > 4:
            return datetime.max
        if date_fmt.startswith('%d'):
            day_month_fmt = '%d %b'
        else:
            day_month_fmt = '%b %d'
        if len(parts[2]) > 8:
            time_fmt = '%H:%M:%S.%f'
        else:
            time_fmt = '%H:%M:%S'
        formats = [day_month_fmt, time_fmt, '%Y']
        tz = parts.pop()
    else:
        if date_fmt.startswith('%Y-'):
            # here the time zone is directly attached to the time
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        else:
            tz = parts.pop()
        if len(parts[0]) > 10:
            return datetime.max
        if len(parts[1]) > 8:
            time_fmt = '%H:%M:%S.%f'
        else:
            time_fmt = '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    if not _has_timezone:
        # parse as naive value and attach the time zone afterwards
        naive = datetime.strptime(' '.join(parts), ' '.join(formats))
        return naive.replace(tzinfo=_get_timezone(tz))
    parts.append(_timezone_as_offset(tz))
    formats.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(formats))
409
410
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_fields(groups):
    """Convert the matched interval groups (as strings) to integers.

    The last group holds the digits after the decimal point of the seconds.
    These digits are a decimal fraction, so they are padded on the right
    (or cut) to six digits in order to get the number of microseconds.
    """
    fields = [int(d) for d in groups[:-1]]
    fields.append(int(groups[-1][:6].ljust(6, '0')))
    return fields


def cast_interval(value):
    """Cast an interval value.

    Returns a timedelta, where years are counted as 365 days and months
    as 30 days.  Raises a ValueError if the value matches none of the
    supported output formats.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        secs_ago = m.pop(5) == '-'
        years, mons, days, hours, mins, secs, usecs = _interval_fields(m)
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            secs_ago = m.pop(5) == '-'
            m = _interval_fields(m)
            if ago:
                # a trailing "ago" inverts the sign of all fields
                m = [-d for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            # all parts of the pattern are optional, so it also matches
            # the empty string; at least one group must be filled
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                hours_ago = m.pop(3) == '-'
                years, mons, days, hours, mins, secs, usecs = (
                    _interval_fields(m))
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = [d or '0' for d in m.groups()]
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    years, mons, days, hours, mins, secs, usecs = (
                        _interval_fields(m))
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
496
497
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.

    Cast functions that take a second argument named "connection" get
    the connection of the instance bound to that argument when they are
    stored in an instance that has its connection attribute set.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
        'float4': float, 'float8': float,
        'numeric': Decimal, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': Uuid,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in local connection specific instances

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.
        A TypeError is raised if the type is not passed as a string.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # create array cast
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                # store only if base type exists
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument.

        This is the case if any parameter after the first one is named
        "connection".  Callables that cannot be inspected are assumed
        to need no connection.
        """
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type.

        The default is returned when no cast function exists for the type.
        """
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        The type can be a single name or an iterable of names.  Passing
        None as cast function removes the stored typecasts instead.
        In both cases, the cached array typecasts for the given types
        are removed as well.  Raises a TypeError if cast is neither
        None nor callable.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.
        """
        defaults = self.defaults
        if typ is None:
            self.clear()
            # The defaults must not be copied with update(), since cast
            # functions taking a connection argument need to be bound to
            # the connection of this instance, as in the other branch.
            for t, cast in defaults.items():
                self[t] = self._add_connection(cast)
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                cast = defaults.get(t)
                if cast:
                    self[t] = self._add_connection(cast)
                    # also reset the corresponding array type
                    t = '_%s' % t
                    cast = defaults.get(t)
                    if cast:
                        self[t] = self._add_connection(cast)
                    else:
                        self.pop(t, None)
                else:
                    self.pop(t, None)
                    self.pop('_%s' % t, None)

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast.

        The returned function passes its value and the base cast
        to the typecast function that is stored for 'anyarray'.
        """
        cast_array = self['anyarray']
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts.

        The returned function passes its value and the casts to the
        typecast function that is stored for 'record', and returns the
        result as a named tuple with the given name and field names.
        """
        cast_record = self['record']
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
621
622
_typecasts = Typecasts()  # this is the global typecast dictionary


def get_typecast(typ):
    """Get the global typecast function for the given database type.

    This returns the result of the get() method of the global typecast
    dictionary for the given type name.
    """
    return _typecasts.get(typ)
629
630
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    The arguments are passed on to the set() method of the global
    typecast dictionary.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    _typecasts.set(typ, cast)
638
639
def reset_typecast(typ=None):
    """Reset the global typecasts for the given type(s) to their default.

    When no type is specified, all typecasts will be reset.
    The argument is passed on to the reset() method of the global
    typecast dictionary.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    _typecasts.reset(typ)
649
650
class LocalTypecasts(Typecasts):
    """Map typecasts, including local composite types, to cast functions.

    Other than in the base class, the defaults are taken from the
    global typecast dictionary.
    """

    defaults = _typecasts

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Array casts are only stored if a cast exists for the base type.
        For types without default cast, a record cast is created if the
        get_fields() method returns any fields for the type.
        """
        if typ.startswith('_'):
            base_cast = self[typ[1:]]
            array_cast = self.create_array_cast(base_cast)
            if base_cast:
                self[typ] = array_cast
            return array_cast
        cast = self.defaults.get(typ)
        if cast:
            cast = self._add_connection(cast)
        else:
            fields = self.get_fields(typ)
            if not fields:
                # neither a default cast nor a composite type
                return cast
            casts = [self[field.type] for field in fields]
            names = [field.name for field in fields]
            cast = self.create_record_cast(typ, names, casts)
        self[typ] = cast
        return cast

    def get_fields(self, typ):
        """Return the fields for the given record type.

        This method will be replaced with a method that looks up the fields
        using the type cache of the connection.
        """
        return []
686
687
class TypeCode(str):
    """Class representing the type_code used by the DB-API 2.0.

    TypeCode objects are strings equal to the PostgreSQL type name,
    but carry some additional information.
    """

    @classmethod
    def create(cls, oid, name, len, type, category, delim, relid):
        """Create a type code for a PostgreSQL data type.

        The name becomes the value of the string, and all other
        arguments are stored as attributes of the same name.
        """
        code = cls(name)
        attributes = (
            ('oid', oid), ('len', len), ('type', type),
            ('category', category), ('delim', delim), ('relid', relid))
        for attribute, attribute_value in attributes:
            setattr(code, attribute, attribute_value)
        return code


# name and type of a field of a composite type
FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
708
709
class TypeCache(dict):
    """Cache for database types.

    This cache maps type OIDs and names to TypeCode strings containing
    important information on the associated database type.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()
        # the typecasts of this cache are bound to the connection
        typecasts = self._typecasts = LocalTypecasts()
        typecasts.get_fields = self.get_fields
        typecasts.connection = cnx
        if cnx.server_version < 80400:
            # older remote databases (not officially supported)
            query = ("SELECT oid, typname,"
                " typlen, typtype, null as typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s")
        else:
            query = ("SELECT oid, typname,"
                " typlen, typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s")
        self._query_pg_type = query

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The key can be the OID (as an integer) or the name of a type.
        The type info is then stored under both the OID and the name.
        Raises a KeyError if the type cannot be found in the database.
        """
        if isinstance(key, int):
            oid = key
        else:
            if not ('.' in key or '"' in key):
                # quote type names that are not qualified or quoted already
                key = '"%s"' % (key,)
            oid = "'%s'::regtype" % (self._escape_string(key),)
        try:
            self._src.execute(self._query_pg_type % (oid,))
        except ProgrammingError:
            rows = None
        else:
            rows = self._src.fetch(1)
        if not rows:
            raise KeyError('Type %s could not be found' % (key,))
        row = rows[0]
        type_code = TypeCode.create(
            int(row[0]), row[1], int(row[2]),
            row[3], row[4], row[5], int(row[6]))
        self[type_code.oid] = type_code
        self[str(type_code)] = type_code
        return type_code

    def get(self, key, default=None):
        """Get the type even if it is not cached."""
        try:
            type_code = self[key]
        except KeyError:
            return default
        return type_code

    def get_fields(self, typ):
        """Get the names and types of the fields of composite types.

        Returns None if the type is unknown or not a composite type.
        """
        if not isinstance(typ, TypeCode):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None  # this type is not composite
        self._src.execute("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % (typ.relid,))
        fields = []
        for name, oid in self._src.fetch(-1):
            fields.append(FieldInfo(name, self.get(int(oid))))
        return fields

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        typecasts = self._typecasts
        return typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        typecasts = self._typecasts
        typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        typecasts = self._typecasts
        typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        The value is returned unchanged if it is None (NULL) or if
        there is no typecast function other than str for the type.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        cast = self.get_typecast(typ)
        if cast and cast is not str:
            return cast(value)
        # no typecast is necessary
        return value
800
801
802class _quotedict(dict):
803    """Dictionary with auto quoting of its items.
804
805    The quote attribute must be set to the desired quote function.
806    """
807
808    def __getitem__(self, key):
809        return self.quote(super(_quotedict, self).__getitem__(key))
810
811
812### Error Messages
813
def _db_error(msg, cls=DatabaseError):
    """Return DatabaseError with empty sqlstate attribute.

    An instance of the given class is created with the message as its
    only argument, and its sqlstate attribute is set to None.
    The error is returned to the caller and not raised here.
    """
    error = cls(msg)
    error.sqlstate = None
    return error
819
820
def _op_error(msg):
    """Return OperationalError.

    The error is created with _db_error(), so its sqlstate is None.
    """
    return _db_error(msg, OperationalError)
824
825
826### Row Tuples
827
828_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$')
829
830# The result rows for database operations are returned as named tuples
831# by default. Since creating namedtuple classes is a somewhat expensive
832# operation, we cache up to 1024 of these classes by default.
833
834@lru_cache(maxsize=1024)
835def _row_factory(names):
836    """Get a namedtuple factory for row results with the given names."""
837    try:
838        try:
839            return namedtuple('Row', names, rename=True)._make
840        except TypeError:  # Python 2.6 and 3.0 do not support rename
841            names = [v if _re_fieldname.match(v) and not iskeyword(v)
842                        else 'column_%d' % (n,)
843                     for n, v in enumerate(names)]
844            return namedtuple('Row', names)._make
845    except ValueError:  # there is still a problem with the field names
846        names = ['column_%d' % (n,) for n in range(len(names))]
847        return namedtuple('Row', names)._make
848
849
def set_row_factory_size(maxsize):
    """Change the size of the namedtuple factory cache.

    If maxsize is set to None, the cache can grow without bound.

    The undecorated factory function is wrapped with a new cache of
    the given size, so the entries cached so far are not carried over.
    """
    global _row_factory
    # __wrapped__ gives access to the function without the old cache
    _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__)
857
858
859### Cursor Object
860
class Cursor(object):
    """Cursor object.

    A cursor executes operations on the connection it was created for
    and gives access to the rows and column description of the result.
    """

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self.type_cache = dbcnx.type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns
        # (None: no result set, True: to be built on demand, otherwise
        # the list of column descriptions that has already been built)
        self._description = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a custom row factory has been set statically, so it must
            # not be replaced whenever a new result set is available
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object."""
        self.close()

    def _quote(self, value):
        """Quote value depending on its type.

        Returns the representation of the value that gets substituted
        into the operation string.  Values of types that are not handled
        here are adapted by calling their __pg_repr__() method; if they
        do not provide this method, an InterfaceError is raised.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, (Hstore, Json)):
            value = str(value)
        if isinstance(value, basestring):
            if isinstance(value, Binary):
                value = self._cnx.escape_bytea(value)
                if bytes is not str:  # Python >= 3.0
                    value = value.decode('ascii')
            else:
                value = self._cnx.escape_string(value)
            return "'%s'" % (value,)
        if isinstance(value, float):
            # special float values must be passed as quoted strings
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal, Literal)):
            return value
        # datetime must be checked before date (it is a subclass of date)
        if isinstance(value, datetime):
            if value.tzinfo:
                return "'%s'::timestamptz" % (value,)
            return "'%s'::timestamp" % (value,)
        if isinstance(value, date):
            return "'%s'::date" % (value,)
        if isinstance(value, time):
            if value.tzinfo:
                return "'%s'::timetz" % (value,)
            return "'%s'::time" % (value,)
        if isinstance(value, timedelta):
            return "'%s'::interval" % (value,)
        if isinstance(value, Uuid):
            return "'%s'::uuid" % (value,)
        if isinstance(value, list):
            # Quote value as an ARRAY constructor. This is better than using
            # an array literal because it carries the information that this is
            # an array and not a string.  One issue with this syntax is that
            # you need to add an explicit typecast when passing empty arrays.
            # The ARRAY keyword is actually only necessary at the top level.
            if not value:  # exception for empty array
                return "'{}'"
            q = self._quote
            try:
                return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),)
            except UnicodeEncodeError:  # Python 2 with non-ascii values
                return u'ARRAY[%s]' % (','.join(unicode(q(v)) for v in value),)
        if isinstance(value, tuple):
            # Quote as a ROW constructor.  This is better than using a record
            # literal because it carries the information that this is a record
            # and not a string.  We don't use the keyword ROW in order to make
            # this usable with the IN syntax as well.  It is only necessary
            # when the record has a single column which is not really useful.
            q = self._quote
            try:
                return '(%s)' % (','.join(str(q(v)) for v in value),)
            except UnicodeEncodeError:  # Python 2 with non-ascii values
                return u'(%s)' % (','.join(unicode(q(v)) for v in value),)
        try:
            value = value.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'Do not know how to adapt type %s' % (type(value),))
        if isinstance(value, (tuple, list)):
            value = self._quote(value)
        return value

    def _quoteparams(self, string, parameters):
        """Quote parameters.

        This function works for both mappings and sequences.

        The function should be used even when there are no parameters,
        so that we have a consistent behavior regarding percent signs.
        """
        if not parameters:
            try:
                return string % ()  # unescape literal quotes if possible
            except (TypeError, ValueError):
                return string  # silently accept unescaped quotes
        if isinstance(parameters, dict):
            # the values are quoted when they are looked up by the % operator
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def _make_description(self, info):
        """Make the description tuple for the given field info.

        The first item of the field info is ignored, the other ones are
        taken as name, type, size and type modifier of the column.
        """
        name, typ, size, mod = info[1:]
        type_code = self.type_cache[typ]
        if mod > 0:
            # NOTE(review): presumably the modifier includes a 4 byte
            # header that must be removed here -- confirm
            mod -= 4
        if type_code == 'numeric':
            # precision is taken from the upper bits of the modifier
            # and scale from its lower 16 bits
            precision, scale = mod >> 16, mod & 0xffff
            size = precision
        else:
            if not size:
                size = type_code.size
            if size == -1:
                size = mod
            precision = scale = None
        return CursorDescription(name, type_code,
            None, size, precision, scale, None)

    @property
    def description(self):
        """Read-only attribute describing the result columns.

        This is None if there is no result set.  Otherwise the list of
        column descriptions is built on first access and then kept.
        """
        descr = self._description
        if descr is True:
            make = self._make_description
            descr = [make(info) for info in self._src.listinfo()]
            self._description = descr
        return descr

    @property
    def colnames(self):
        """Unofficial convenience property for getting the column names.

        This is None if there is no result description.
        """
        description = self.description
        return None if description is None else [d[0] for d in description]

    @property
    def coltypes(self):
        """Unofficial convenience property for getting the column types.

        This is None if there is no result description.
        """
        description = self.description
        return None if description is None else [d[1] for d in description]

    def close(self):
        """Close the cursor object."""
        self._src.close()

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command)."""
        # The parameters may also be specified as list of tuples to e.g.
        # insert multiple rows in a single operation, but this kind of
        # usage is deprecated.  We make several plausibility checks because
        # tuples can also be passed with the meaning of ROW constructors.
        if (parameters and isinstance(parameters, list)
                and len(parameters) > 1
                and all(isinstance(p, tuple) for p in parameters)
                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
            return self.executemany(operation, parameters)
        else:
            # not a list of tuples
            return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        Nothing is done and None is returned when the parameter sequence
        is empty.  Otherwise a transaction is started unless the connection
        already has a pending one, the operation is run once for each item
        of the sequence, and the cursor itself is returned.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self._description = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        sql = "BEGIN"  # kept up to date for use in the error messages
        try:
            if not self._dbcnx._tnx:
                try:
                    self._src.execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("Can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                if rows:  # sum up the values returned by the source
                    rowcount += rows
                else:
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "Error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("Internal error in '%s': %s" % (sql, err))
        # then initialize result raw count and description
        if self._src.resulttype == RESULT_DQL:
            self._description = True  # fetch on demand
            self.rowcount = self._src.ntuples
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set.

        Returns None when no more rows are available.
        """
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        typecast = self.type_cache.typecast
        # look these up only once instead of once for every row
        coltypes = self.coltypes
        row_factory = self.row_factory
        return [row_factory([typecast(value, typ)
            for typ, value in zip(coltypes, row)]) for row in result]

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = len(parameters) if parameters else 0
        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
        self.execute(query, parameters)
        return parameters

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.

        The size option sets the size of the buffer used when reading data
        from file-like objects.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.

        The cursor itself is returned, so that operations can be chained.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            # not a file-like object, must be a string or an iterable
            if size:
                raise ValueError("Size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("The input must be %s" % (type_name,))
                if not binary_format:
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    else:
                        if not stream.endswith(b'\n'):
                            stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "Input stream must consist of %s"
                                % (type_name,))
                        # rows are only terminated in the non binary
                        # formats, binary data must be passed unchanged
                        if not binary_format:
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            else:
                                if not chunk.endswith(b'\n'):
                                    chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("Need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, int):
                raise TypeError("The size option must be an integer")
            if size > 0:

                def chunks():
                    # read in pieces until a short or empty read occurs
                    while True:
                        buffer = read(size)
                        yield buffer
                        if not buffer or len(buffer) < size:
                            break

            else:

                def chunks():
                    # read everything at once
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("Must specify a table, not a query")
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("The format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("Invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("The sep option must be a string")
            if binary_format:
                raise ValueError(
                    "The sep option is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError(
                    "The sep option must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("The null option must be a string")
            options.append('null %s')
            params.append(null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % (','.join(options),))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("Need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("Columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("The format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("Invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("The sep option must be a string")
            if binary_format:
                raise ValueError(
                    "The sep option is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError(
                    "The sep option must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("The null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            if binary_format:
                decode = False
            else:
                # decode by default only if str is the unicode type
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("The decode option must be a boolean")
            if decode and binary_format:
                raise ValueError(
                    "The decode option is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % (','.join(options),))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            self.rowcount = 0
            while True:
                row = getdata(decode)
                if isinstance(row, int):
                    # an integer marks the end of the data, it is
                    # taken as the final row count if it differs
                    if self.rowcount != row:
                        self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no output stream, return the generator
            return copy()

        # write the rows to the file-like output stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Note that since Python 2.6 the iterator protocol uses __next__()
    # instead of next(), we keep it only for backward compatibility of pgdb.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("The nextset() method is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.

        Returns None if there are no column names.
        """
        names = self.colnames
        if names:
            return _row_factory(tuple(names))
1436
1437
# Named tuple type for the items of the description attribute of cursors,
# the fields are the seven items of a DB-API 2 column description.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
1441
1442
1443### Connection Objects
1444
class Connection(object):
    """Connection object.

    Wraps a low-level connection and keeps track of whether a
    transaction is pending on it (see the _tnx attribute).
    """

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object."""
        self._cnx = cnx  # connection
        self._tnx = False  # transaction state
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor
        # creating a source object serves as a check of the connection
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("Invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction.
        """
        # roll back if any exception information has been passed
        if any(item is not None for item in (et, ev, tb)):
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection object."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            # discard a pending transaction, ignoring database errors
            try:
                self.rollback()
            except DatabaseError:
                pass
        self._cnx.close()
        self._cnx = None

    @property
    def closed(self):
        """Check whether the connection has been closed or is broken."""
        try:
            if not self._cnx:
                return True
            return self._cnx.status != 1
        except TypeError:
            return True

    def commit(self):
        """Commit any pending transaction to the database."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return  # nothing to commit
        self._tnx = False
        try:
            self._cnx.source().execute("COMMIT")
        except DatabaseError:
            raise
        except Exception:
            raise _op_error("Can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return  # nothing to roll back
        self._tnx = False
        try:
            self._cnx.source().execute("ROLLBACK")
        except DatabaseError:
            raise
        except Exception:
            raise _op_error("Can't rollback")

    def cursor(self):
        """Return a new cursor object using the connection."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("Invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            new_cursor = self.cursor()
            new_cursor.execute(operation, params)
            return new_cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            new_cursor = self.cursor()
            new_cursor.executemany(operation, param_seq)
            return new_cursor
1560
1561
1562### Module Interface
1563
# Keep a reference to the connect function that is in scope at this point,
# because the name is rebound by the DB-API 2 compliant connect() function
# defined below, which calls the original one as _connect().
_connect = connect
1565
def connect(dsn=None,
        user=None, password=None,
        host=None, database=None, **kwargs):
    """Connect to a database.

    The dsn is a string of the form 'host:database:user:password:opt'
    where all parts are optional.  The user, password, host and database
    can also be passed as keyword arguments which take precedence over
    the values in the dsn.  A port can be appended to the host, separated
    by a colon.  Additional keyword arguments are passed on as part of a
    connection info string, together with the database name.

    Returns a Connection object for the opened connection.
    """
    # first get params from DSN
    dbport = -1
    dbhost = ""
    dbname = ""
    dbuser = ""
    dbpasswd = ""
    dbopt = ""
    try:
        # take as many parts as there are; a missing part ends the
        # assignments with an IndexError, keeping the other defaults
        params = dsn.split(":")
        dbhost = params[0]
        dbname = params[1]
        dbuser = params[2]
        dbpasswd = params[3]
        dbopt = params[4]
    except (AttributeError, IndexError, TypeError):
        pass

    # override if necessary
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbname = database
    if host is not None:
        try:
            params = host.split(":")
            dbhost = params[0]
            dbport = int(params[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass

    # empty host is localhost
    if dbhost == "":
        dbhost = None
    if dbuser == "":
        dbuser = None

    # pass keyword arguments as connection info string
    if kwargs:
        kwargs = list(kwargs.items())
        if '=' in dbname:
            # the database name is already a connection info string
            dbname = [dbname]
        else:
            kwargs.insert(0, ('dbname', dbname))
            dbname = []
        for kw, value in kwargs:
            value = str(value)
            if not value or ' ' in value:
                # Escape the backslashes first and then the single quotes,
                # otherwise the backslashes that have been added in front
                # of the quotes would be doubled again.
                value = "'%s'" % (value.replace(
                    '\\', '\\\\').replace("'", "\\'"),)
            dbname.append('%s=%s' % (kw, value))
        dbname = ' '.join(dbname)

    # open the connection
    cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
1627
1628
1629### Types Handling
1630
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A type object is a frozen set of type names that compares equal to
    each of the names it contains, with or without a leading underscore.
    """

    def __new__(cls, values):
        """Create the type from an iterable or a string of type names."""
        names = values.split() if isinstance(values, basestring) else values
        return super(Type, cls).__new__(cls, names)

    def __eq__(self, other):
        """Check whether a type name belongs to this type."""
        if not isinstance(other, basestring):
            return super(Type, self).__eq__(other)
        # a leading underscore is ignored in the comparison
        name = other[1:] if other.startswith('_') else other
        return name in self

    def __ne__(self, other):
        """Check whether a type name does not belong to this type."""
        if not isinstance(other, basestring):
            return super(Type, self).__ne__(other)
        # a leading underscore is ignored in the comparison
        name = other[1:] if other.startswith('_') else other
        return name not in self
1658
1659
class ArrayType:
    """Type class for PostgreSQL array types.

    Compares equal to all type names starting with an underscore
    and to other instances of this class.
    """

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return isinstance(other, ArrayType)
        return other.startswith('_')

    def __ne__(self, other):
        if not isinstance(other, basestring):
            return not isinstance(other, ArrayType)
        return not other.startswith('_')
1674
1675
class RecordType:
    """Type class for PostgreSQL record types.

    Compares equal to type codes with the type attribute 'c', to the
    type name 'record' and to other instances of this class.
    """

    def __eq__(self, other):
        if isinstance(other, TypeCode):
            return other.type == 'c'
        if isinstance(other, basestring):
            return other == 'record'
        return isinstance(other, RecordType)

    def __ne__(self, other):
        if isinstance(other, TypeCode):
            return other.type != 'c'
        if isinstance(other, basestring):
            return other != 'record'
        return not isinstance(other, RecordType)
1694
1695
# Mandatory type objects defined by DB-API 2 specs:
# (each one compares equal to any of the listed type names)

STRING = Type('char bpchar name text varchar')
BINARY = Type('bytea')
NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
DATETIME = Type(
    'date time timetz timestamp timestamptz interval '
    'abstime reltime')  # the last two are very old
ROWID = Type('oid')
1704
1705
# Additional type objects (more specific):
# (these go beyond the mandatory ones and partly overlap with them,
# e.g. all names in INTEGER and FLOAT are also contained in NUMBER)

BOOL = Type('bool')
SMALLINT = Type('int2')
INTEGER = Type('int2 int4 int8 serial')
LONG = Type('int8')
FLOAT = Type('float4 float8')
NUMERIC = Type('numeric')
MONEY = Type('money')
DATE = Type('date')
TIME = Type('time timetz')
TIMESTAMP = Type('timestamp timestamptz')
INTERVAL = Type('interval')
UUID = Type('uuid')
HSTORE = Type('hstore')
JSON = Type('json jsonb')
1722
# Type object for arrays (also equate to their base types):
# (it compares equal to all type names with a leading underscore)

ARRAY = ArrayType()

# Type object for records (encompassing all composite types):
# (it compares equal to the type name 'record' as well)

RECORD = RecordType()
1730
1731
1732# Mandatory type helpers defined by DB-API 2 specs:
1733
def Date(year, month, day):
    """Construct an object holding a date value.

    The given year, month and day are passed on to the date class.
    """
    return date(year=year, month=month, day=day)
1737
1738
def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Construct an object holding a time value.

    All parts except the hour are optional; the arguments are passed
    on to the time class.
    """
    return time(
        hour=hour, minute=minute, second=second,
        microsecond=microsecond, tzinfo=tzinfo)
1742
1743
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Construct an object holding a time stamp value.

    The time parts are optional; the arguments are passed on to the
    datetime class.
    """
    return datetime(
        year=year, month=month, day=day,
        hour=hour, minute=minute, second=second,
        microsecond=microsecond, tzinfo=tzinfo)
1748
1749
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks are converted with localtime(), and the first three fields
    of the result (year, month, day) are passed to Date().
    """
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
1753
1754
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    The ticks are converted with localtime(), and the fields with index
    3 to 5 of the result (hour, minute, second) are passed to Time().
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
1758
1759
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    The ticks are converted with localtime(), and the first six fields
    of the result (year through second) are passed to Timestamp().
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
1763
1764
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is a plain subclass of bytes that adds no behavior of its own;
    it only marks a value as binary data.
    """
1767
1768
1769# Additional type helpers for PyGreSQL:
1770
def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
    """Construct an object holding a time interval value.

    Only the number of days is required; all arguments are passed on
    to the timedelta class.
    """
    return timedelta(
        days=days, hours=hours, minutes=minutes,
        seconds=seconds, microseconds=microseconds)
1775
1776
# Construct an object holding a UUID value
# (Uuid is already bound further up in this module; it is only
# repeated here so that it appears among the other type helpers)
Uuid = Uuid
1778
1779
class Hstore(dict):
    """Wrapper class for marking hstore values.

    Converting an instance with str() gives the text representation of
    the dictionary as comma-separated key=>value pairs, with keys and
    values quoted and escaped where necessary.
    """

    # Tokens that must be enclosed in double quotes: the bare word NULL
    # in any letter case (it would otherwise be rendered like a None
    # value) and everything containing whitespace, a comma, an equals
    # sign or a greater-than sign.
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>]')
    # Double quotes and backslashes must be escaped with a backslash.
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s):
        """Return the text representation of a single key or value.

        None is rendered as the unquoted word NULL and the empty string
        as an empty pair of double quotes.  Values that are not strings
        (such as numbers) are converted with str() first.
        """
        if s is None:
            return 'NULL'
        if not isinstance(s, basestring):
            # must happen before the emptiness check below, because
            # otherwise falsy values such as 0 would be rendered as ""
            s = str(s)
        if not s:
            return '""'
        # decide on quoting before escaping changes the string
        quote = cls._re_quote.search(s)
        s = cls._re_escape.sub(r'\\\1', s)
        if quote:
            s = '"%s"' % (s,)
        return s

    def __str__(self):
        """Return the items as comma-separated key=>value pairs."""
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
1801
1802
class Json:
    """Construct a wrapper for holding an object serializable to JSON.

    The wrapped object is kept in the obj attribute and the function
    used for serializing it in the encode attribute.
    """

    def __init__(self, obj, encode=None):
        """Wrap obj, using jsonencode if no encode function is passed."""
        self.obj = obj
        self.encode = encode if encode else jsonencode

    def __str__(self):
        """Return the wrapped object as a string.

        Strings are returned as they are; all other objects are
        serialized with the encode function.
        """
        value = self.obj
        return value if isinstance(value, basestring) else self.encode(value)
1815
1816
class Literal:
    """Construct a wrapper for holding a literal SQL string.

    The wrapped string is kept in the sql attribute.
    """

    def __init__(self, sql):
        """Wrap the given SQL string."""
        self.sql = sql

    def __str__(self):
        """Return the wrapped SQL string as it is."""
        return self.sql

    # the same function is also made available under this name
    __pg_repr__ = __str__
1827
1828
# If run as script, print some information:
# (the version, followed by an empty line and the module docstring)

if __name__ == '__main__':
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.