source: trunk/pgdb.py

Last change on this file was 995, checked in by cito, 4 months ago

Support autocommit attribute on pgdb connections

1#!/usr/bin/python
2#
3# $Id: pgdb.py 995 2019-04-25 14:10:20Z cito $
4#
5# PyGreSQL - a Python interface for the PostgreSQL database.
6#
7# This file contains the DB-API 2 compatible pgdb module.
8#
9# Copyright (c) 2019 by the PyGreSQL Development Team
10#
11# Please see the LICENSE.TXT file for specific restrictions.
12
13"""pgdb - DB-API 2.0 compliant module for PygreSQL.
14
15(c) 1999, Pascal Andre <andre@via.ecp.fr>.
16See package documentation for further information on copyright.
17
18Inline documentation is sparse.
19See DB-API 2.0 specification for usage information:
20http://www.python.org/peps/pep-0249.html
21
22Basic usage:
23
24    pgdb.connect(connect_string) # open a connection
25    # connect_string = 'host:database:user:password:opt'
26    # All parts are optional. You may also pass host through
27    # password as keyword arguments. To pass a port,
28    # pass it in the host keyword parameter:
29    connection = pgdb.connect(host='localhost:5432')
30
31    cursor = connection.cursor() # open a cursor
32
33    cursor.execute(query[, params])
34    # Execute a query, binding params (a dictionary) if they are
35    # passed. The binding syntax is the same as the % operator
36    # for dictionaries, and no quoting is done.
37
38    cursor.executemany(query, list of params)
39    # Execute a query many times, binding each param dictionary
40    # from the list.
41
42    cursor.fetchone() # fetch one row, [value, value, ...]
43
44    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
45
46    cursor.fetchmany([size])
47    # returns size or cursor.arraysize number of rows,
48    # [[value, value, ...], ...] from result set.
49    # Default cursor.arraysize is 1.
50
51    cursor.description # returns information about the columns
52    #   [(column_name, type_name, display_size,
53    #           internal_size, precision, scale, null_ok), ...]
54    # Note that display_size, precision, scale and null_ok
55    # are not implemented.
56
57    cursor.rowcount # number of rows available in the result set
58    # Available after a call to execute.
59
60    connection.commit() # commit transaction
61
62    connection.rollback() # or rollback transaction
63
64    cursor.close() # close the cursor
65
66    connection.close() # close the connection
67"""
68
69from __future__ import print_function, division
70
71from _pg import *
72
73__version__ = version
74
75from datetime import date, time, datetime, timedelta, tzinfo
76from time import localtime
77from decimal import Decimal
78from uuid import UUID as Uuid
79from math import isnan, isinf
80try:
81    from collections.abc import Iterable
82except ImportError:  # Python < 3.3
83    from collections import Iterable
84from collections import namedtuple
85from keyword import iskeyword
86from functools import partial
87from re import compile as regex
88from json import loads as jsondecode, dumps as jsonencode
89
90try:  # noinspection PyUnresolvedReferences
91    long
92except NameError:  # Python >= 3.0
93    long = int
94
95try:  # noinspection PyUnresolvedReferences
96    unicode
97except NameError:  # Python >= 3.0
98    unicode = str
99
100try:  # noinspection PyUnresolvedReferences
101    basestring
102except NameError:  # Python >= 3.0
103    basestring = (str, bytes)
104
105try:
106    from functools import lru_cache
107except ImportError:  # Python < 3.2
108    from functools import update_wrapper
109    try:
110        from _thread import RLock
111    except ImportError:
112        class RLock:  # for builds without threads
113            def __enter__(self): pass
114
115            def __exit__(self, exctype, excinst, exctb): pass
116
117    def lru_cache(maxsize=128):
118        """Simplified functools.lru_cache decorator for one argument."""
119
120        def decorator(function):
121            sentinel = object()
122            cache = {}
123            get = cache.get
124            lock = RLock()
125            root = []
126            root_full = [root, False]
127            root[:] = [root, root, None, None]
128
129            if maxsize == 0:
130
131                def wrapper(arg):
132                    res = function(arg)
133                    return res
134
135            elif maxsize is None:
136
137                def wrapper(arg):
138                    res = get(arg, sentinel)
139                    if res is not sentinel:
140                        return res
141                    res = function(arg)
142                    cache[arg] = res
143                    return res
144
145            else:
146
147                def wrapper(arg):
148                    with lock:
149                        link = get(arg)
150                        if link is not None:
151                            root = root_full[0]
152                            prev, next, _arg, res = link
153                            prev[1] = next
154                            next[0] = prev
155                            last = root[0]
156                            last[1] = root[0] = link
157                            link[0] = last
158                            link[1] = root
159                            return res
160                    res = function(arg)
161                    with lock:
162                        root, full = root_full
163                        if arg in cache:
164                            pass
165                        elif full:
166                            oldroot = root
167                            oldroot[2] = arg
168                            oldroot[3] = res
169                            root = root_full[0] = oldroot[1]
170                            oldarg = root[2]
171                            oldres = root[3]  # keep reference
172                            root[2] = root[3] = None
173                            del cache[oldarg]
174                            cache[arg] = oldroot
175                        else:
176                            last = root[0]
177                            link = [last, root, arg, res]
178                            last[1] = root[0] = cache[arg] = link
179                            if len(cache) >= maxsize:
180                                root_full[1] = True
181                    return res
182
183            wrapper.__wrapped__ = function
184            return update_wrapper(wrapper, function)
185
186        return decorator
187
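# Editorial sketch (not part of pgdb): the fallback above emulates
# functools.lru_cache for functions taking exactly one hashable argument.
# The function name below is made up for illustration.

def _example_lru_cache():
    """Show that repeated calls with the same argument hit the cache."""
    @lru_cache(maxsize=2)
    def square(x):
        return x * x
    assert square(3) == 9  # computed and cached
    assert square(3) == 9  # second call is served from the cache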
188
189### Module Constants
190
191# compliant with DB API 2.0
192apilevel = '2.0'
193
194# module may be shared, but not connections
195threadsafety = 1
196
197# this module uses extended Python format codes
198paramstyle = 'pyformat'
199
200# shortcut methods have been excluded from DB API 2 and
201# are not recommended by the DB SIG, but they can be handy
202shortcutmethods = 1
203
204
205### Internal Type Handling
206
207try:
208    from inspect import signature
209except ImportError:  # Python < 3.3
210    from inspect import getargspec
211
212    def get_args(func):
213        return getargspec(func).args
214else:
215
216    def get_args(func):
217        return list(signature(func).parameters)
218
219try:
220    from datetime import timezone
221except ImportError:  # Python < 3.2
222
223    class timezone(tzinfo):
224        """Simple timezone implementation."""
225
226        def __init__(self, offset, name=None):
227            self.offset = offset
228            if not name:
229                minutes = self.offset.days * 1440 + self.offset.seconds // 60
230                if minutes < 0:
231                    hours, minutes = divmod(-minutes, 60)
232                    hours = -hours
233                else:
234                    hours, minutes = divmod(minutes, 60)
235                name = 'UTC%+03d:%02d' % (hours, minutes)
236            self.name = name
237
238        def utcoffset(self, dt):
239            return self.offset
240
241        def tzname(self, dt):
242            return self.name
243
244        def dst(self, dt):
245            return None
246
247    timezone.utc = timezone(timedelta(0), 'UTC')
248
249    _has_timezone = False
250else:
251    _has_timezone = True
252
253# time zones used in Postgres timestamptz output
254_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
255    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
256    UCT='+0000', UTC='+0000', WET='+0000')
257
258
259def _timezone_as_offset(tz):
260    if tz.startswith(('+', '-')):
261        if len(tz) < 5:
262            return tz + '00'
263        return tz.replace(':', '')
264    return _timezones.get(tz, '+0000')
265
266
267def _get_timezone(tz):
268    tz = _timezone_as_offset(tz)
269    minutes = 60 * int(tz[1:3]) + int(tz[3:5])
270    if tz[0] == '-':
271        minutes = -minutes
272    return timezone(timedelta(minutes=minutes), tz)
273
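# Editorial sketch (not part of pgdb): how the timezone helpers above map
# Postgres timezone suffixes to tzinfo objects.

def _example_timezones():
    """Convert abbreviations and offsets used in timestamptz output."""
    assert _timezone_as_offset('CET') == '+0100'
    assert _timezone_as_offset('+02:30') == '+0230'
    assert _get_timezone('-0500').utcoffset(None) == timedelta(hours=-5)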
274
275def decimal_type(decimal_type=None):
276    """Get or set global type to be used for decimal values.
277
278    Note that connections cache cast functions. To be sure a global change
279    is picked up by a running connection, call con.type_cache.reset_typecast().
280    """
281    global Decimal
282    if decimal_type is not None:
283        Decimal = decimal_type
284        set_typecast('numeric', decimal_type)
285    return Decimal
286
287
288def cast_bool(value):
289    """Cast boolean value in database format to bool."""
290    if value:
291        return value[0] in ('t', 'T')
292
293
294def cast_money(value):
295    """Cast money value in database format to Decimal."""
296    if value:
297        value = value.replace('(', '-')
298        return Decimal(''.join(c for c in value if c.isdigit() or c in '.-'))
299
300
301def cast_int2vector(value):
302    """Cast an int2vector value."""
303    return [int(v) for v in value.split()]
304
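# Editorial sketch (not part of pgdb): expected results of the simple cast
# functions above; the literal inputs are made up for illustration.

def _example_simple_casts():
    """Check a few representative conversions."""
    assert cast_bool('t') is True
    assert cast_bool('f') is False
    assert cast_money('($123.45)') == Decimal('-123.45')
    assert cast_int2vector('1 2 3') == [1, 2, 3]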
305
306def cast_date(value, connection):
307    """Cast a date value."""
308    # The output format depends on the server setting DateStyle.  The default
309    # setting ISO and the setting for German are actually unambiguous.  The
310    # order of days and months in the other two settings is however ambiguous,
311    # so at least here we need to consult the setting to properly parse values.
312    if value == '-infinity':
313        return date.min
314    if value == 'infinity':
315        return date.max
316    value = value.split()
317    if value[-1] == 'BC':
318        return date.min
319    value = value[0]
320    if len(value) > 10:
321        return date.max
322    fmt = connection.date_format()
323    return datetime.strptime(value, fmt).date()
324
325
326def cast_time(value):
327    """Cast a time value."""
328    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
329    return datetime.strptime(value, fmt).time()
330
331
332_re_timezone = regex('(.*)([+-].*)')
333
334
335def cast_timetz(value):
336    """Cast a timetz value."""
337    tz = _re_timezone.match(value)
338    if tz:
339        value, tz = tz.groups()
340    else:
341        tz = '+0000'
342    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
343    if _has_timezone:
344        value += _timezone_as_offset(tz)
345        fmt += '%z'
346        return datetime.strptime(value, fmt).timetz()
347    return datetime.strptime(value, fmt).timetz().replace(
348        tzinfo=_get_timezone(tz))
349
350
351def cast_timestamp(value, connection):
352    """Cast a timestamp value."""
353    if value == '-infinity':
354        return datetime.min
355    if value == 'infinity':
356        return datetime.max
357    value = value.split()
358    if value[-1] == 'BC':
359        return datetime.min
360    fmt = connection.date_format()
361    if fmt.endswith('-%Y') and len(value) > 2:
362        value = value[1:5]
363        if len(value[3]) > 4:
364            return datetime.max
365        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
366            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
367    else:
368        if len(value[0]) > 10:
369            return datetime.max
370        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
371    return datetime.strptime(' '.join(value), ' '.join(fmt))
372
373
374def cast_timestamptz(value, connection):
375    """Cast a timestamptz value."""
376    if value == '-infinity':
377        return datetime.min
378    if value == 'infinity':
379        return datetime.max
380    value = value.split()
381    if value[-1] == 'BC':
382        return datetime.min
383    fmt = connection.date_format()
384    if fmt.endswith('-%Y') and len(value) > 2:
385        value = value[1:]
386        if len(value[3]) > 4:
387            return datetime.max
388        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
389            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
390        value, tz = value[:-1], value[-1]
391    else:
392        if fmt.startswith('%Y-'):
393            tz = _re_timezone.match(value[1])
394            if tz:
395                value[1], tz = tz.groups()
396            else:
397                tz = '+0000'
398        else:
399            value, tz = value[:-1], value[-1]
400        if len(value[0]) > 10:
401            return datetime.max
402        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
403    if _has_timezone:
404        value.append(_timezone_as_offset(tz))
405        fmt.append('%z')
406        return datetime.strptime(' '.join(value), ' '.join(fmt))
407    return datetime.strptime(' '.join(value), ' '.join(fmt)).replace(
408        tzinfo=_get_timezone(tz))
409
410
411_re_interval_sql_standard = regex(
412    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
413    '(?:([+-]?[0-9]+)(?!:) ?)?'
414    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
415
416_re_interval_postgres = regex(
417    '(?:([+-]?[0-9]+) ?years? ?)?'
418    '(?:([+-]?[0-9]+) ?mons? ?)?'
419    '(?:([+-]?[0-9]+) ?days? ?)?'
420    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
421
422_re_interval_postgres_verbose = regex(
423    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
424    '(?:([+-]?[0-9]+) ?mons? ?)?'
425    '(?:([+-]?[0-9]+) ?days? ?)?'
426    '(?:([+-]?[0-9]+) ?hours? ?)?'
427    '(?:([+-]?[0-9]+) ?mins? ?)?'
428    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')
429
430_re_interval_iso_8601 = regex(
431    'P(?:([+-]?[0-9]+)Y)?'
432    '(?:([+-]?[0-9]+)M)?'
433    '(?:([+-]?[0-9]+)D)?'
434    '(?:T(?:([+-]?[0-9]+)H)?'
435    '(?:([+-]?[0-9]+)M)?'
436    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')
437
438
439def cast_interval(value):
440    """Cast an interval value."""
441    # The output format depends on the server setting IntervalStyle, but it's
442    # not necessary to consult this setting to parse it.  It's faster to just
443    # check all possible formats, and there is no ambiguity here.
444    m = _re_interval_iso_8601.match(value)
445    if m:
446        m = [d or '0' for d in m.groups()]
447        secs_ago = m.pop(5) == '-'
448        m = [int(d) for d in m]
449        years, mons, days, hours, mins, secs, usecs = m
450        if secs_ago:
451            secs = -secs
452            usecs = -usecs
453    else:
454        m = _re_interval_postgres_verbose.match(value)
455        if m:
456            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
457            secs_ago = m.pop(5) == '-'
458            m = [-int(d) for d in m] if ago else [int(d) for d in m]
459            years, mons, days, hours, mins, secs, usecs = m
460            if secs_ago:
461                secs = - secs
462                usecs = -usecs
463        else:
464            m = _re_interval_postgres.match(value)
465            if m and any(m.groups()):
466                m = [d or '0' for d in m.groups()]
467                hours_ago = m.pop(3) == '-'
468                m = [int(d) for d in m]
469                years, mons, days, hours, mins, secs, usecs = m
470                if hours_ago:
471                    hours = -hours
472                    mins = -mins
473                    secs = -secs
474                    usecs = -usecs
475            else:
476                m = _re_interval_sql_standard.match(value)
477                if m and any(m.groups()):
478                    m = [d or '0' for d in m.groups()]
479                    years_ago = m.pop(0) == '-'
480                    hours_ago = m.pop(3) == '-'
481                    m = [int(d) for d in m]
482                    years, mons, days, hours, mins, secs, usecs = m
483                    if years_ago:
484                        years = -years
485                        mons = -mons
486                    if hours_ago:
487                        hours = -hours
488                        mins = -mins
489                        secs = -secs
490                        usecs = -usecs
491                else:
492                    raise ValueError('Cannot parse interval: %s' % value)
493    days += 365 * years + 30 * mons
494    return timedelta(days=days, hours=hours, minutes=mins,
495        seconds=secs, microseconds=usecs)
496
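# Editorial sketch (not part of pgdb): cast_interval() accepts all of the
# IntervalStyle output formats; note that years and months are approximated
# with 365 and 30 days respectively when building the timedelta.

def _example_cast_interval():
    """Parse the same interval in different server output styles."""
    expected = timedelta(days=1, hours=2, minutes=30)
    assert cast_interval('1 day 02:30:00') == expected           # postgres
    assert cast_interval('@ 1 day 2 hours 30 mins') == expected  # verbose
    assert cast_interval('P1DT2H30M') == expected                # iso_8601
    assert cast_interval('1 year') == timedelta(days=365)        # approximation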
497
498class Typecasts(dict):
499    """Dictionary mapping database types to typecast functions.
500
501    The cast functions get passed the string representation of a value in
502    the database which they need to convert to a Python object.  The
503    passed string will never be None since NULL values are already
504    handled before the cast function is called.
505    """
506
507    # the default cast functions
508    # (str functions are ignored but have been added for faster access)
509    defaults = {'char': str, 'bpchar': str, 'name': str,
510        'text': str, 'varchar': str,
511        'bool': cast_bool, 'bytea': unescape_bytea,
512        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
513        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
514        'float4': float, 'float8': float,
515        'numeric': Decimal, 'money': cast_money,
516        'date': cast_date, 'interval': cast_interval,
517        'time': cast_time, 'timetz': cast_timetz,
518        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
519        'int2vector': cast_int2vector, 'uuid': Uuid,
520        'anyarray': cast_array, 'record': cast_record}
521
522    connection = None  # will be set in local connection specific instances
523
524    def __missing__(self, typ):
525        """Create a cast function if it is not cached.
526
527        Note that this class never raises a KeyError,
528        but returns None when no special cast function exists.
529        """
530        if not isinstance(typ, str):
531            raise TypeError('Invalid type: %s' % typ)
532        cast = self.defaults.get(typ)
533        if cast:
534            # store default for faster access
535            cast = self._add_connection(cast)
536            self[typ] = cast
537        elif typ.startswith('_'):
538            # create array cast
539            base_cast = self[typ[1:]]
540            cast = self.create_array_cast(base_cast)
541            if base_cast:
542                # store only if base type exists
543                self[typ] = cast
544        return cast
545
546    @staticmethod
547    def _needs_connection(func):
548        """Check if a typecast function needs a connection argument."""
549        try:
550            args = get_args(func)
551        except (TypeError, ValueError):
552            return False
553        else:
554            return 'connection' in args[1:]
555
556    def _add_connection(self, cast):
557        """Add a connection argument to the typecast function if necessary."""
558        if not self.connection or not self._needs_connection(cast):
559            return cast
560        return partial(cast, connection=self.connection)
561
562    def get(self, typ, default=None):
563        """Get the typecast function for the given database type."""
564        return self[typ] or default
565
566    def set(self, typ, cast):
567        """Set a typecast function for the specified database type(s)."""
568        if isinstance(typ, basestring):
569            typ = [typ]
570        if cast is None:
571            for t in typ:
572                self.pop(t, None)
573                self.pop('_%s' % t, None)
574        else:
575            if not callable(cast):
576                raise TypeError("Cast parameter must be callable")
577            for t in typ:
578                self[t] = self._add_connection(cast)
579                self.pop('_%s' % t, None)
580
581    def reset(self, typ=None):
582        """Reset the typecasts for the specified type(s) to their defaults.
583
584        When no type is specified, all typecasts will be reset.
585        """
586        defaults = self.defaults
587        if typ is None:
588            self.clear()
589            self.update(defaults)
590        else:
591            if isinstance(typ, basestring):
592                typ = [typ]
593            for t in typ:
594                cast = defaults.get(t)
595                if cast:
596                    self[t] = self._add_connection(cast)
597                    t = '_%s' % t
598                    cast = defaults.get(t)
599                    if cast:
600                        self[t] = self._add_connection(cast)
601                    else:
602                        self.pop(t, None)
603                else:
604                    self.pop(t, None)
605                    self.pop('_%s' % t, None)
606
607    def create_array_cast(self, basecast):
608        """Create an array typecast for the given base cast."""
609        cast_array = self['anyarray']
610        def cast(v):
611            return cast_array(v, basecast)
612        return cast
613
614    def create_record_cast(self, name, fields, casts):
615        """Create a named record typecast for the given fields and casts."""
616        cast_record = self['record']
617        record = namedtuple(name, fields)
618        def cast(v):
619            return record(*cast_record(v, casts))
620        return cast
621
622
623_typecasts = Typecasts()  # this is the global typecast dictionary
624
625
626def get_typecast(typ):
627    """Get the global typecast function for the given database type(s)."""
628    return _typecasts.get(typ)
629
630
631def set_typecast(typ, cast):
632    """Set a global typecast function for the given database type(s).
633
634    Note that connections cache cast functions. To be sure a global change
635    is picked up by a running connection, call con.type_cache.reset_typecast().
636    """
637    _typecasts.set(typ, cast)
638
639
640def reset_typecast(typ=None):
641    """Reset the global typecasts for the given type(s) to their default.
642
643    When no type is specified, all typecasts will be reset.
644
645    Note that connections cache cast functions. To be sure a global change
646    is picked up by a running connection, call con.type_cache.reset_typecast().
647    """
648    _typecasts.reset(typ)
649
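# Editorial sketch (not part of pgdb): adjusting the global typecasts.
# Remember that open connections cache cast functions, so a running
# connection would additionally need con.type_cache.reset_typecast().

def _example_global_typecasts():
    """Replace the numeric cast with float and restore the default."""
    set_typecast('numeric', float)
    assert get_typecast('numeric') is float
    reset_typecast('numeric')  # back to the default Decimal cast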
650
651class LocalTypecasts(Typecasts):
652    """Map typecasts, including local composite types, to cast functions."""
653
654    defaults = _typecasts
655
656    connection = None  # will be set in a connection specific instance
657
658    def __missing__(self, typ):
659        """Create a cast function if it is not cached."""
660        if typ.startswith('_'):
661            base_cast = self[typ[1:]]
662            cast = self.create_array_cast(base_cast)
663            if base_cast:
664                self[typ] = cast
665        else:
666            cast = self.defaults.get(typ)
667            if cast:
668                cast = self._add_connection(cast)
669                self[typ] = cast
670            else:
671                fields = self.get_fields(typ)
672                if fields:
673                    casts = [self[field.type] for field in fields]
674                    fields = [field.name for field in fields]
675                    cast = self.create_record_cast(typ, fields, casts)
676                    self[typ] = cast
677        return cast
678
679    def get_fields(self, typ):
680        """Return the fields for the given record type.
681
682        This method will be replaced with a method that looks up the fields
683        using the type cache of the connection.
684        """
685        return []
686
687
688class TypeCode(str):
689    """Class representing the type_code used by the DB-API 2.0.
690
691    TypeCode objects are strings equal to the PostgreSQL type name,
692    but carry some additional information.
693    """
694
695    @classmethod
696    def create(cls, oid, name, len, type, category, delim, relid):
697        """Create a type code for a PostgreSQL data type."""
698        self = cls(name)
699        self.oid = oid
700        self.len = len
701        self.type = type
702        self.category = category
703        self.delim = delim
704        self.relid = relid
705        return self
706
707FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
708
709
710class TypeCache(dict):
711    """Cache for database types.
712
713    This cache maps type OIDs and names to TypeCode strings containing
714    important information on the associated database type.
715    """
716
717    def __init__(self, cnx):
718        """Initialize type cache for connection."""
719        super(TypeCache, self).__init__()
720        self._escape_string = cnx.escape_string
721        self._src = cnx.source()
722        self._typecasts = LocalTypecasts()
723        self._typecasts.get_fields = self.get_fields
724        self._typecasts.connection = cnx
725        if cnx.server_version < 80400:
726            # older remote databases (not officially supported)
727            self._query_pg_type = ("SELECT oid, typname,"
728                " typlen, typtype, null as typcategory, typdelim, typrelid"
729                " FROM pg_type WHERE oid=%s")
730        else:
731            self._query_pg_type = ("SELECT oid, typname,"
732                " typlen, typtype, typcategory, typdelim, typrelid"
733                " FROM pg_type WHERE oid=%s")
734
735    def __missing__(self, key):
736        """Get the type info from the database if it is not cached."""
737        if isinstance(key, int):
738            oid = key
739        else:
740            if '.' not in key and '"' not in key:
741                key = '"%s"' % (key,)
742            oid = "'%s'::regtype" % (self._escape_string(key),)
743        try:
744            self._src.execute(self._query_pg_type % (oid,))
745        except ProgrammingError:
746            res = None
747        else:
748            res = self._src.fetch(1)
749        if not res:
750            raise KeyError('Type %s could not be found' % (key,))
751        res = res[0]
752        type_code = TypeCode.create(int(res[0]), res[1],
753            int(res[2]), res[3], res[4], res[5], int(res[6]))
754        self[type_code.oid] = self[str(type_code)] = type_code
755        return type_code
756
757    def get(self, key, default=None):
758        """Get the type even if it is not cached."""
759        try:
760            return self[key]
761        except KeyError:
762            return default
763
764    def get_fields(self, typ):
765        """Get the names and types of the fields of composite types."""
766        if not isinstance(typ, TypeCode):
767            typ = self.get(typ)
768            if not typ:
769                return None
770        if not typ.relid:
771            return None  # this type is not composite
772        self._src.execute("SELECT attname, atttypid"
773            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
774            " AND NOT attisdropped ORDER BY attnum" % (typ.relid,))
775        return [FieldInfo(name, self.get(int(oid)))
776            for name, oid in self._src.fetch(-1)]
777
778    def get_typecast(self, typ):
779        """Get the typecast function for the given database type."""
780        return self._typecasts.get(typ)
781
782    def set_typecast(self, typ, cast):
783        """Set a typecast function for the specified database type(s)."""
784        self._typecasts.set(typ, cast)
785
786    def reset_typecast(self, typ=None):
787        """Reset the typecast function for the specified database type(s)."""
788        self._typecasts.reset(typ)
789
790    def typecast(self, value, typ):
791        """Cast the given value according to the given database type."""
792        if value is None:
793            # for NULL values, no typecast is necessary
794            return None
795        cast = self.get_typecast(typ)
796        if not cast or cast is str:
797            # no typecast is necessary
798            return value
799        return cast(value)
800
801
802class _quotedict(dict):
803    """Dictionary with auto quoting of its items.
804
805    The quote attribute must be set to the desired quote function.
806    """
807
808    def __getitem__(self, key):
809        return self.quote(super(_quotedict, self).__getitem__(key))
810
811
812### Error Messages
813
814def _db_error(msg, cls=DatabaseError):
815    """Return DatabaseError with empty sqlstate attribute."""
816    error = cls(msg)
817    error.sqlstate = None
818    return error
819
820
821def _op_error(msg):
822    """Return OperationalError."""
823    return _db_error(msg, OperationalError)
824
825
826### Row Tuples
827
828_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$')
829
830# The result rows for database operations are returned as named tuples
831# by default. Since creating namedtuple classes is a somewhat expensive
832# operation, we cache up to 1024 of these classes by default.
833
834@lru_cache(maxsize=1024)
835def _row_factory(names):
836    """Get a namedtuple factory for row results with the given names."""
837    try:
838        try:
839            return namedtuple('Row', names, rename=True)._make
840        except TypeError:  # Python 2.6 and 3.0 do not support rename
841            names = [v if _re_fieldname.match(v) and not iskeyword(v)
842                        else 'column_%d' % (n,)
843                     for n, v in enumerate(names)]
844            return namedtuple('Row', names)._make
845    except ValueError:  # there is still a problem with the field names
846        names = ['column_%d' % (n,) for n in range(len(names))]
847        return namedtuple('Row', names)._make
848
849
850def set_row_factory_size(maxsize):
851    """Change the size of the namedtuple factory cache.
852
853    If maxsize is set to None, the cache can grow without bound.
854    """
855    global _row_factory
856    _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__)
857
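# Editorial sketch (not part of pgdb): what the cached row factory produces.
# The column names are made up for illustration.

def _example_row_factory():
    """Build a named-tuple row from a list of values."""
    make_row = _row_factory(('id', 'name'))
    row = make_row([1, 'Alice'])
    assert row.id == 1 and row.name == 'Alice'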
858
859### Cursor Object
860
861class Cursor(object):
862    """Cursor object."""
863
864    def __init__(self, dbcnx):
865        """Create a cursor object for the database connection."""
866        self.connection = self._dbcnx = dbcnx
867        self._cnx = dbcnx._cnx
868        self.type_cache = dbcnx.type_cache
869        self._src = self._cnx.source()
870        # the official attribute for describing the result columns
871        self._description = None
872        if self.row_factory is Cursor.row_factory:
873            # the row factory needs to be determined dynamically
874            self.row_factory = None
875        else:
876            self.build_row_factory = None
877        self.rowcount = -1
878        self.arraysize = 1
879        self.lastrowid = None
880
881    def __iter__(self):
882        """Make cursor compatible to the iteration protocol."""
883        return self
884
885    def __enter__(self):
886        """Enter the runtime context for the cursor object."""
887        return self
888
889    def __exit__(self, et, ev, tb):
890        """Exit the runtime context for the cursor object."""
891        self.close()
892
893    def _quote(self, value):
894        """Quote value depending on its type."""
895        if value is None:
896            return 'NULL'
897        if isinstance(value, (Hstore, Json)):
898            value = str(value)
899        if isinstance(value, basestring):
900            if isinstance(value, Binary):
901                value = self._cnx.escape_bytea(value)
902                if bytes is not str:  # Python >= 3.0
903                    value = value.decode('ascii')
904            else:
905                value = self._cnx.escape_string(value)
906            return "'%s'" % (value,)
907        if isinstance(value, float):
908            if isinf(value):
909                return "'-Infinity'" if value < 0 else "'Infinity'"
910            if isnan(value):
911                return "'NaN'"
912            return value
913        if isinstance(value, (int, long, Decimal, Literal)):
914            return value
915        if isinstance(value, datetime):
916            if value.tzinfo:
917                return "'%s'::timestamptz" % (value,)
918            return "'%s'::timestamp" % (value,)
919        if isinstance(value, date):
920            return "'%s'::date" % (value,)
921        if isinstance(value, time):
922            if value.tzinfo:
923                return "'%s'::timetz" % (value,)
924            return "'%s'::time" % value
925        if isinstance(value, timedelta):
926            return "'%s'::interval" % (value,)
927        if isinstance(value, Uuid):
928            return "'%s'::uuid" % (value,)
929        if isinstance(value, list):
930            # Quote value as an ARRAY constructor. This is better than using
931            # an array literal because it carries the information that this is
932            # an array and not a string.  One issue with this syntax is that
933            # you need to add an explicit typecast when passing empty arrays.
934            # The ARRAY keyword is actually only necessary at the top level.
935            if not value:  # exception for empty array
936                return "'{}'"
937            q = self._quote
938            try:
939                return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),)
940            except UnicodeEncodeError:  # Python 2 with non-ascii values
941                return u'ARRAY[%s]' % (','.join(unicode(q(v)) for v in value),)
942        if isinstance(value, tuple):
943            # Quote as a ROW constructor.  This is better than using a record
944            # literal because it carries the information that this is a record
945            # and not a string.  We don't use the keyword ROW in order to make
946            # this usable with the IN syntax as well.  It is only necessary
947            # when the record has a single column, which is not really useful.
948            q = self._quote
949            try:
950                return '(%s)' % (','.join(str(q(v)) for v in value),)
951            except UnicodeEncodeError:  # Python 2 with non-ascii values
952                return u'(%s)' % (','.join(unicode(q(v)) for v in value),)
953        try:
954            value = value.__pg_repr__()
955        except AttributeError:
956            raise InterfaceError(
957                'Do not know how to adapt type %s' % (type(value),))
958        if isinstance(value, (tuple, list)):
959            value = self._quote(value)
960        return value
961
962    def _quoteparams(self, string, parameters):
963        """Quote parameters.
964
965        This function works for both mappings and sequences.
966
967        The function should be used even when there are no parameters,
968        so that we have a consistent behavior regarding percent signs.
969        """
970        if not parameters:
971            try:
972                return string % ()  # unescape literal quotes if possible
973            except (TypeError, ValueError):
974                return string  # silently accept unescaped quotes
975        if isinstance(parameters, dict):
976            parameters = _quotedict(parameters)
977            parameters.quote = self._quote
978        else:
979            parameters = tuple(map(self._quote, parameters))
980        return string % parameters
981
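    # Editorial note (illustration only, names made up): with the 'pyformat'
    # paramstyle, _quoteparams() fills both mapping and sequence placeholders,
    # quoting each value with _quote(), e.g. (assuming standard string
    # escaping on the server):
    #     cur._quoteparams("select * from t where name=%(name)s", {'name': "O'Hara"})
    #     => "select * from t where name='O''Hara'"
    #     cur._quoteparams("select * from t where id=%s and price<%s", (42, 9.5))
    #     => "select * from t where id=42 and price<9.5"
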
982    def _make_description(self, info):
983        """Make the description tuple for the given field info."""
984        name, typ, size, mod = info[1:]
985        type_code = self.type_cache[typ]
986        if mod > 0:
987            mod -= 4
988        if type_code == 'numeric':
989            precision, scale = mod >> 16, mod & 0xffff
990            size = precision
991        else:
992            if not size:
993                size = type_code.size
994            if size == -1:
995                size = mod
996            precision = scale = None
997        return CursorDescription(name, type_code,
998            None, size, precision, scale, None)
999
1000    @property
1001    def description(self):
1002        """Read-only attribute describing the result columns."""
1003        descr = self._description
1004        if self._description is True:
1005            make = self._make_description
1006            descr = [make(info) for info in self._src.listinfo()]
1007            self._description = descr
1008        return descr
1009
1010    @property
1011    def colnames(self):
1012        """Unofficial convenience method for getting the column names."""
1013        return [d[0] for d in self.description]
1014
1015    @property
1016    def coltypes(self):
1017        """Unofficial convenience method for getting the column types."""
1018        return [d[1] for d in self.description]
1019
1020    def close(self):
1021        """Close the cursor object."""
1022        self._src.close()
1023
1024    def execute(self, operation, parameters=None):
1025        """Prepare and execute a database operation (query or command)."""
1026        # The parameters may also be specified as list of tuples to e.g.
1027        # insert multiple rows in a single operation, but this kind of
1028        # usage is deprecated.  We make several plausibility checks because
1029        # tuples can also be passed with the meaning of ROW constructors.
1030        if (parameters and isinstance(parameters, list)
1031                and len(parameters) > 1
1032                and all(isinstance(p, tuple) for p in parameters)
1033                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
1034            return self.executemany(operation, parameters)
1035        else:
1036            # not a list of tuples
1037            return self.executemany(operation, [parameters])
1038
1039    def executemany(self, operation, seq_of_parameters):
1040        """Prepare operation and execute it against a parameter sequence."""
1041        if not seq_of_parameters:
1042            # don't do anything without parameters
1043            return
1044        self._description = None
1045        self.rowcount = -1
1046        # first try to execute all queries
1047        rowcount = 0
1048        sql = "BEGIN"
1049        try:
1050            if not self._dbcnx._tnx and not self._dbcnx.autocommit:
1051                try:
1052                    self._src.execute(sql)
1053                except DatabaseError:
1054                    raise  # database provides error message
1055                except Exception:
1056                    raise _op_error("Can't start transaction")
1057                else:
1058                    self._dbcnx._tnx = True
1059            for parameters in seq_of_parameters:
1060                sql = operation
1061                sql = self._quoteparams(sql, parameters)
1062                rows = self._src.execute(sql)
1063                if rows:  # true if not DML
1064                    rowcount += rows
1065                else:
1066                    self.rowcount = -1
1067        except DatabaseError:
1068            raise  # database provides error message
1069        except Error as err:
1070            raise _db_error(
1071                "Error in '%s': '%s' " % (sql, err), InterfaceError)
1072        except Exception as err:
1073            raise _op_error("Internal error in '%s': %s" % (sql, err))
1074        # then initialize result raw count and description
1075        if self._src.resulttype == RESULT_DQL:
1076            self._description = True  # fetch on demand
1077            self.rowcount = self._src.ntuples
1078            self.lastrowid = None
1079            if self.build_row_factory:
1080                self.row_factory = self.build_row_factory()
1081        else:
1082            self.rowcount = rowcount
1083            self.lastrowid = self._src.oidstatus()
1084        # return the cursor object, so you can write statements such as
1085        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
1086        return self
1087
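    # Editorial example (illustration only, table and values made up):
    # execute() binds dict parameters using pyformat placeholders and
    # returns the cursor, so calls can be chained:
    #     cur.execute("select * from weather where city=%(city)s",
    #                 {'city': 'Boston'}).fetchall()
    #     cur.executemany("insert into weather values (%(city)s, %(temp)s)",
    #                     [{'city': 'Boston', 'temp': 21},
    #                      {'city': 'Oslo', 'temp': 5}])
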
1088    def fetchone(self):
1089        """Fetch the next row of a query result set."""
1090        res = self.fetchmany(1, False)
1091        try:
1092            return res[0]
1093        except IndexError:
1094            return None
1095
1096    def fetchall(self):
1097        """Fetch all (remaining) rows of a query result."""
1098        return self.fetchmany(-1, False)
1099
1100    def fetchmany(self, size=None, keep=False):
1101        """Fetch the next set of rows of a query result.
1102
1103        The number of rows to fetch per call is specified by the
1104        size parameter. If it is not given, the cursor's arraysize
1105        determines the number of rows to be fetched. If you set
1106        the keep parameter to true, this is kept as new arraysize.
1107        """
1108        if size is None:
1109            size = self.arraysize
1110        if keep:
1111            self.arraysize = size
1112        try:
1113            result = self._src.fetch(size)
1114        except DatabaseError:
1115            raise
1116        except Error as err:
1117            raise _db_error(str(err))
1118        typecast = self.type_cache.typecast
1119        return [self.row_factory([typecast(value, typ)
1120            for typ, value in zip(self.coltypes, row)]) for row in result]
1121
1122    def callproc(self, procname, parameters=None):
1123        """Call a stored database procedure with the given name.
1124
1125        The sequence of parameters must contain one entry for each input
1126        argument that the procedure expects. The result of the call is the
1127        same as this input sequence; replacement of output and input/output
1128        parameters in the return value is currently not supported.
1129
1130        The procedure may also provide a result set as output. These can be
1131        requested through the standard fetch methods of the cursor.
1132        """
1133        n = parameters and len(parameters) or 0
1134        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
1135        self.execute(query, parameters)
1136        return parameters
1137
1138    def copy_from(self, stream, table,
1139            format=None, sep=None, null=None, size=None, columns=None):
1140        """Copy data from an input stream to the specified table.
1141
1142        The input stream can be a file-like object with a read() method or
1143        it can also be an iterable returning a row or multiple rows of input
1144        on each iteration.
1145
1146        The format must be text, csv or binary. The sep option sets the
1147        column separator (delimiter) used in the non-binary formats.
1148        The null option sets the textual representation of NULL in the input.
1149
1150        The size option sets the size of the buffer used when reading data
1151        from file-like objects.
1152
1153        The copy operation can be restricted to a subset of columns. If no
1154        columns are specified, all of them will be copied.
1155        """
1156        binary_format = format == 'binary'
1157        try:
1158            read = stream.read
1159        except AttributeError:
1160            if size:
1161                raise ValueError("Size must only be set for file-like objects")
1162            if binary_format:
1163                input_type = bytes
1164                type_name = 'byte strings'
1165            else:
1166                input_type = basestring
1167                type_name = 'strings'
1168
1169            if isinstance(stream, basestring):
1170                if not isinstance(stream, input_type):
1171                    raise ValueError("The input must be %s" % (type_name,))
1172                if not binary_format:
1173                    if isinstance(stream, str):
1174                        if not stream.endswith('\n'):
1175                            stream += '\n'
1176                    else:
1177                        if not stream.endswith(b'\n'):
1178                            stream += b'\n'
1179
1180                def chunks():
1181                    yield stream
1182
1183            elif isinstance(stream, Iterable):
1184
1185                def chunks():
1186                    for chunk in stream:
1187                        if not isinstance(chunk, input_type):
1188                            raise ValueError(
1189                                "Input stream must consist of %s"
1190                                % (type_name,))
1191                        if isinstance(chunk, str):
1192                            if not chunk.endswith('\n'):
1193                                chunk += '\n'
1194                        else:
1195                            if not chunk.endswith(b'\n'):
1196                                chunk += b'\n'
1197                        yield chunk
1198
1199            else:
1200                raise TypeError("Need an input stream to copy from")
1201        else:
1202            if size is None:
1203                size = 8192
1204            elif not isinstance(size, int):
1205                raise TypeError("The size option must be an integer")
1206            if size > 0:
1207
1208                def chunks():
1209                    while True:
1210                        buffer = read(size)
1211                        yield buffer
1212                        if not buffer or len(buffer) < size:
1213                            break
1214
1215            else:
1216
1217                def chunks():
1218                    yield read()
1219
1220        if not table or not isinstance(table, basestring):
1221            raise TypeError("Need a table to copy to")
1222        if table.lower().startswith('select'):
1223            raise ValueError("Must specify a table, not a query")
1224        else:
1225            table = '"%s"' % (table,)
1226        operation = ['copy %s' % (table,)]
1227        options = []
1228        params = []
1229        if format is not None:
1230            if not isinstance(format, basestring):
1231                raise TypeError("The format option must be a string")
1232            if format not in ('text', 'csv', 'binary'):
1233                raise ValueError("Invalid format")
1234            options.append('format %s' % (format,))
1235        if sep is not None:
1236            if not isinstance(sep, basestring):
1237                raise TypeError("The sep option must be a string")
1238            if format == 'binary':
1239                raise ValueError(
1240                    "The sep option is not allowed with binary format")
1241            if len(sep) != 1:
1242                raise ValueError(
1243                    "The sep option must be a single one-byte character")
1244            options.append('delimiter %s')
1245            params.append(sep)
1246        if null is not None:
1247            if not isinstance(null, basestring):
1248                raise TypeError("The null option must be a string")
1249            options.append('null %s')
1250            params.append(null)
1251        if columns:
1252            if not isinstance(columns, basestring):
1253                columns = ','.join('"%s"' % (col,) for col in columns)
1254            operation.append('(%s)' % (columns,))
1255        operation.append("from stdin")
1256        if options:
1257            operation.append('(%s)' % (','.join(options),))
1258        operation = ' '.join(operation)
1259
1260        putdata = self._src.putdata
1261        self.execute(operation, params)
1262
1263        try:
1264            for chunk in chunks():
1265                putdata(chunk)
1266        except BaseException as error:
1267            self.rowcount = -1
1268            # the following call will re-raise the error
1269            putdata(error)
1270        else:
1271            self.rowcount = putdata(None)
1272
1273        # return the cursor object, so you can chain operations
1274        return self
1275
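    # Editorial example (illustration only, table and data made up): copying
    # rows into a table from an iterable of text lines; with the default
    # text format the column separator is a tab:
    #     rows = ['1\tBoston\n', '2\tOslo\n']
    #     cur.copy_from(rows, 'cities', columns=('id', 'name'))
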
1276    def copy_to(self, stream, table,
1277            format=None, sep=None, null=None, decode=None, columns=None):
1278        """Copy data from the specified table to an output stream.
1279
1280        The output stream can be a file-like object with a write() method or
1281        it can also be None, in which case the method will return a generator
1282        yielding a row on each iteration.
1283
1284        Output will be returned as byte strings unless you set decode to true.
1285
1286        Note that you can also use a select query instead of the table name.
1287
1288        The format must be text, csv or binary. The sep option sets the
1289        column separator (delimiter) used in the non-binary formats.
1290        The null option sets the textual representation of NULL in the output.
1291
1292        The copy operation can be restricted to a subset of columns. If no
1293        columns are specified, all of them will be copied.
1294        """
1295        binary_format = format == 'binary'
1296        if stream is not None:
1297            try:
1298                write = stream.write
1299            except AttributeError:
1300                raise TypeError("Need an output stream to copy to")
1301        if not table or not isinstance(table, basestring):
1302            raise TypeError("Need a table to copy from")
1303        if table.lower().startswith('select'):
1304            if columns:
1305                raise ValueError("Columns must be specified in the query")
1306            table = '(%s)' % (table,)
1307        else:
1308            table = '"%s"' % (table,)
1309        operation = ['copy %s' % (table,)]
1310        options = []
1311        params = []
1312        if format is not None:
1313            if not isinstance(format, basestring):
1314                raise TypeError("The format option must be a string")
1315            if format not in ('text', 'csv', 'binary'):
1316                raise ValueError("Invalid format")
1317            options.append('format %s' % (format,))
1318        if sep is not None:
1319            if not isinstance(sep, basestring):
1320                raise TypeError("The sep option must be a string")
1321            if binary_format:
1322                raise ValueError(
1323                    "The sep option is not allowed with binary format")
1324            if len(sep) != 1:
1325                raise ValueError(
1326                    "The sep option must be a single one-byte character")
1327            options.append('delimiter %s')
1328            params.append(sep)
1329        if null is not None:
1330            if not isinstance(null, basestring):
1331                raise TypeError("The null option must be a string")
1332            options.append('null %s')
1333            params.append(null)
1334        if decode is None:
1335            if format == 'binary':
1336                decode = False
1337            else:
1338                decode = str is unicode
1339        else:
1340            if not isinstance(decode, (int, bool)):
1341                raise TypeError("The decode option must be a boolean")
1342            if decode and binary_format:
1343                raise ValueError(
1344                    "The decode option is not allowed with binary format")
1345        if columns:
1346            if not isinstance(columns, basestring):
1347                columns = ','.join('"%s"' % (col,) for col in columns)
1348            operation.append('(%s)' % (columns,))
1349
1350        operation.append("to stdout")
1351        if options:
1352            operation.append('(%s)' % (','.join(options),))
1353        operation = ' '.join(operation)
1354
1355        getdata = self._src.getdata
1356        self.execute(operation, params)
1357
1358        def copy():
1359            self.rowcount = 0
1360            while True:
1361                row = getdata(decode)
1362                if isinstance(row, int):
1363                    if self.rowcount != row:
1364                        self.rowcount = row
1365                    break
1366                self.rowcount += 1
1367                yield row
1368
1369        if stream is None:
1370            # no output stream, return the generator
1371            return copy()
1372
1373        # write the rows to the file-like output stream
1374        for row in copy():
1375            write(row)
1376
1377        # return the cursor object, so you can chain operations
1378        return self
1379
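    # Editorial example (illustration only, table made up): with stream=None,
    # copy_to() returns a generator of rows; decode=True yields str instead
    # of bytes in Python 3:
    #     for line in cur.copy_to(None, 'cities', decode=True):
    #         print(line, end='')
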
1380    def __next__(self):
1381        """Return the next row (support for the iteration protocol)."""
1382        res = self.fetchone()
1383        if res is None:
1384            raise StopIteration
1385        return res
1386
1387    # Note that since Python 2.6 the iterator protocol uses __next__()
1388    # instead of next(); we keep next() only for backward compatibility of pgdb.
1389    next = __next__
1390
1391    @staticmethod
1392    def nextset():
1393        """Not supported."""
1394        raise NotSupportedError("The nextset() method is not supported")
1395
1396    @staticmethod
1397    def setinputsizes(sizes):
1398        """Not supported."""
1399        pass  # unsupported, but silently passed
1400
1401    @staticmethod
1402    def setoutputsize(size, column=0):
1403        """Not supported."""
1404        pass  # unsupported, but silently passed
1405
1406    @staticmethod
1407    def row_factory(row):
1408        """Process rows before they are returned.
1409
1410        You can overwrite this statically with a custom row factory, or
1411        you can build a row factory dynamically with build_row_factory().
1412
1413        For example, you can create a Cursor class that returns rows as
1414        Python dictionaries like this:
1415
1416            class DictCursor(pgdb.Cursor):
1417
1418                def row_factory(self, row):
1419                    return {desc[0]: value
1420                        for desc, value in zip(self.description, row)}
1421
1422            cur = DictCursor(con)  # get one DictCursor instance or
1423            con.cursor_type = DictCursor  # always use DictCursor instances
1424        """
1425        raise NotImplementedError
1426
1427    def build_row_factory(self):
1428        """Build a row factory based on the current description.
1429
1430        This implementation builds a row factory for creating named tuples.
1431        You can overwrite this method if you want to dynamically create
1432        different row factories whenever the column description changes.
1433        """
1434        names = self.colnames
1435        if names:
1436            return _row_factory(tuple(names))
1437
1438
1439CursorDescription = namedtuple('CursorDescription',
1440    ['name', 'type_code', 'display_size', 'internal_size',
1441     'precision', 'scale', 'null_ok'])
1442
1443
1444### Connection Objects
1445
1446class Connection(object):
1447    """Connection object."""
1448
1449    # expose the exceptions as attributes on the connection object
1450    Error = Error
1451    Warning = Warning
1452    InterfaceError = InterfaceError
1453    DatabaseError = DatabaseError
1454    InternalError = InternalError
1455    OperationalError = OperationalError
1456    ProgrammingError = ProgrammingError
1457    IntegrityError = IntegrityError
1458    DataError = DataError
1459    NotSupportedError = NotSupportedError
1460
1461    def __init__(self, cnx):
1462        """Create a database connection object."""
1463        self._cnx = cnx  # connection
1464        self._tnx = False  # transaction state
1465        self.type_cache = TypeCache(cnx)
1466        self.cursor_type = Cursor
1467        self.autocommit = False
1468        try:
1469            self._cnx.source()
1470        except Exception:
1471            raise _op_error("Invalid connection")
1472
1473    def __enter__(self):
1474        """Enter the runtime context for the connection object.
1475
1476        The runtime context can be used for running transactions.
1477
1478        This also starts a transaction in autocommit mode.
1479        """
1480        if self.autocommit:
1481            try:
1482                self._cnx.source().execute("BEGIN")
1483            except DatabaseError:
1484                raise  # database provides error message
1485            except Exception:
1486                raise _op_error("Can't start transaction")
1487            else:
1488                self._tnx = True
1489        return self
1490
1491    def __exit__(self, et, ev, tb):
1492        """Exit the runtime context for the connection object.
1493
1494        This does not close the connection, but it ends a transaction.
1495        """
1496        if et is None and ev is None and tb is None:
1497            self.commit()
1498        else:
1499            self.rollback()
1500
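    # The runtime context defined above makes transaction handling concise:
    # the block commits on success and rolls back when an exception escapes.
    # A minimal sketch (assuming an open connection "con" and an existing
    # table "fruits"):
    #
    #     with con:
    #         cur = con.cursor()
    #         cur.execute("insert into fruits (name) values (%(name)s)",
    #             {'name': 'apple'})
    #     con.close()
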
1501    def close(self):
1502        """Close the connection object."""
1503        if self._cnx:
1504            if self._tnx:
1505                try:
1506                    self.rollback()
1507                except DatabaseError:
1508                    pass
1509            self._cnx.close()
1510            self._cnx = None
1511        else:
1512            raise _op_error("Connection has been closed")
1513
1514    @property
1515    def closed(self):
1516        """Check whether the connection has been closed or is broken."""
1517        try:
1518            return not self._cnx or self._cnx.status != 1
1519        except TypeError:
1520            return True
1521
1522    def commit(self):
1523        """Commit any pending transaction to the database."""
1524        if self._cnx:
1525            if self._tnx:
1526                self._tnx = False
1527                try:
1528                    self._cnx.source().execute("COMMIT")
1529                except DatabaseError:
1530                    raise  # database provides error message
1531                except Exception:
1532                    raise _op_error("Can't commit transaction")
1533        else:
1534            raise _op_error("Connection has been closed")
1535
1536    def rollback(self):
1537        """Roll back to the start of any pending transaction."""
1538        if self._cnx:
1539            if self._tnx:
1540                self._tnx = False
1541                try:
1542                    self._cnx.source().execute("ROLLBACK")
1543                except DatabaseError:
1544                    raise  # database provides error message
1545                except Exception:
1546                    raise _op_error("Can't rollback transaction")
1547        else:
1548            raise _op_error("Connection has been closed")
1549
1550    def cursor(self):
1551        """Return a new cursor object using the connection."""
1552        if self._cnx:
1553            try:
1554                return self.cursor_type(self)
1555            except Exception:
1556                raise _op_error("Invalid connection")
1557        else:
1558            raise _op_error("Connection has been closed")
1559
1560    if shortcutmethods:  # otherwise do not implement and document this
1561
1562        def execute(self, operation, params=None):
1563            """Shortcut method to run an operation on an implicit cursor."""
1564            cursor = self.cursor()
1565            cursor.execute(operation, params)
1566            return cursor
1567
1568        def executemany(self, operation, param_seq):
1569            """Shortcut method to run an operation against a parameter sequence."""
1570            cursor = self.cursor()
1571            cursor.executemany(operation, param_seq)
1572            return cursor
1573
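        # When the module-level shortcutmethods flag is enabled, an operation
        # can be run directly on the connection; a brief sketch:
        #
        #     cur = con.execute("select version()")
        #     print(cur.fetchone())
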
1574
1575### Module Interface
1576
1577_connect = connect
1578
1579def connect(dsn=None,
1580        user=None, password=None,
1581        host=None, database=None, **kwargs):
1582    """Connect to a database."""
1583    # first get params from DSN
1584    dbport = -1
1585    dbhost = ""
1586    dbname = ""
1587    dbuser = ""
1588    dbpasswd = ""
1589    dbopt = ""
1590    try:
1591        params = dsn.split(":")
1592        dbhost = params[0]
1593        dbname = params[1]
1594        dbuser = params[2]
1595        dbpasswd = params[3]
1596        dbopt = params[4]
1597    except (AttributeError, IndexError, TypeError):
1598        pass
1599
1600    # override if necessary
1601    if user is not None:
1602        dbuser = user
1603    if password is not None:
1604        dbpasswd = password
1605    if database is not None:
1606        dbname = database
1607    if host is not None:
1608        try:
1609            params = host.split(":")
1610            dbhost = params[0]
1611            dbport = int(params[1])
1612        except (AttributeError, IndexError, TypeError, ValueError):
1613            pass
1614
1615    # empty host is localhost
1616    if dbhost == "":
1617        dbhost = None
1618    if dbuser == "":
1619        dbuser = None
1620
1621    # pass keyword arguments as connection info string
1622    if kwargs:
1623        kwargs = list(kwargs.items())
1624        if '=' in dbname:
1625            dbname = [dbname]
1626        else:
1627            kwargs.insert(0, ('dbname', dbname))
1628            dbname = []
1629        for kw, value in kwargs:
1630            value = str(value)
1631            if not value or ' ' in value:
1632                value = "'%s'" % (value.replace(
1633                    '\\', '\\\\').replace("'", "\\'"),)
1634            dbname.append('%s=%s' % (kw, value))
1635        dbname = ' '.join(dbname)
1636
1637    # open the connection
1638    cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd)
1639    return Connection(cnx)
1640
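# Note that additional keyword arguments are folded into a libpq connection
# info string; a rough sketch of the effect (the option names shown are only
# examples of common libpq options):
#
#     connect(database='test', sslmode='require', connect_timeout=10)
#     # passes the conninfo string:
#     #     dbname=test sslmode=require connect_timeout=10
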
1641
1642### Types Handling
1643
1644class Type(frozenset):
1645    """Type class for a group of PostgreSQL data types.
1646
1647    PostgreSQL is object-oriented: types are dynamic.
1648    We must thus use type names as internal type codes.
1649    """
1650
1651    def __new__(cls, values):
1652        if isinstance(values, basestring):
1653            values = values.split()
1654        return super(Type, cls).__new__(cls, values)
1655
1656    def __eq__(self, other):
1657        if isinstance(other, basestring):
1658            if other.startswith('_'):
1659                other = other[1:]
1660            return other in self
1661        else:
1662            return super(Type, self).__eq__(other)
1663
1664    def __ne__(self, other):
1665        if isinstance(other, basestring):
1666            if other.startswith('_'):
1667                other = other[1:]
1668            return other not in self
1669        else:
1670            return super(Type, self).__ne__(other)
1671
1672
1673class ArrayType:
1674    """Type class for PostgreSQL array types."""
1675
1676    def __eq__(self, other):
1677        if isinstance(other, basestring):
1678            return other.startswith('_')
1679        else:
1680            return isinstance(other, ArrayType)
1681
1682    def __ne__(self, other):
1683        if isinstance(other, basestring):
1684            return not other.startswith('_')
1685        else:
1686            return not isinstance(other, ArrayType)
1687
1688
1689class RecordType:
1690    """Type class for PostgreSQL record types."""
1691
1692    def __eq__(self, other):
1693        if isinstance(other, TypeCode):
1694            return other.type == 'c'
1695        elif isinstance(other, basestring):
1696            return other == 'record'
1697        else:
1698            return isinstance(other, RecordType)
1699
1700    def __ne__(self, other):
1701        if isinstance(other, TypeCode):
1702            return other.type != 'c'
1703        elif isinstance(other, basestring):
1704            return other != 'record'
1705        else:
1706            return not isinstance(other, RecordType)
1707
1708
1709# Mandatory type objects defined by DB-API 2 specs:
1710
1711STRING = Type('char bpchar name text varchar')
1712BINARY = Type('bytea')
1713NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
1714DATETIME = Type('date time timetz timestamp timestamptz interval'
1715    ' abstime reltime')  # these are very old
1716ROWID = Type('oid')
1717
1718
1719# Additional type objects (more specific):
1720
1721BOOL = Type('bool')
1722SMALLINT = Type('int2')
1723INTEGER = Type('int2 int4 int8 serial')
1724LONG = Type('int8')
1725FLOAT = Type('float4 float8')
1726NUMERIC = Type('numeric')
1727MONEY = Type('money')
1728DATE = Type('date')
1729TIME = Type('time timetz')
1730TIMESTAMP = Type('timestamp timestamptz')
1731INTERVAL = Type('interval')
1732UUID = Type('uuid')
1733HSTORE = Type('hstore')
1734JSON = Type('json jsonb')
1735
1736# Type object for arrays (also equate to their base types):
1737
1738ARRAY = ArrayType()
1739
1740# Type object for records (encompassing all composite types):
1741
1742RECORD = RecordType()
1743
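# These type objects compare equal to the type names that appear as type_code
# in cursor.description, so DB-API style type checks remain possible; a brief
# sketch (assuming "cur" has already executed a query):
#
#     col = cur.description[0]
#     if col.type_code == NUMBER:
#         print(col.name, 'is a numeric column')
#     if col.type_code == ARRAY:
#         print(col.name, 'is an array column')
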
1744
1745# Mandatory type helpers defined by DB-API 2 specs:
1746
1747def Date(year, month, day):
1748    """Construct an object holding a date value."""
1749    return date(year, month, day)
1750
1751
1752def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
1753    """Construct an object holding a time value."""
1754    return time(hour, minute, second, microsecond, tzinfo)
1755
1756
1757def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
1758        tzinfo=None):
1759    """Construct an object holding a time stamp value."""
1760    return datetime(year, month, day, hour, minute, second, microsecond, tzinfo)
1761
1762
1763def DateFromTicks(ticks):
1764    """Construct an object holding a date value from the given ticks value."""
1765    return Date(*localtime(ticks)[:3])
1766
1767
1768def TimeFromTicks(ticks):
1769    """Construct an object holding a time value from the given ticks value."""
1770    return Time(*localtime(ticks)[3:6])
1771
1772
1773def TimestampFromTicks(ticks):
1774    """Construct an object holding a time stamp from the given ticks value."""
1775    return Timestamp(*localtime(ticks)[:6])
1776
1777
1778class Binary(bytes):
1779    """Construct an object capable of holding a binary (long) string value."""
1780
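# A few quick illustrations of the helpers above (the values are examples):
#
#     Date(2019, 4, 25)               # -> datetime.date(2019, 4, 25)
#     Time(10, 30)                    # -> datetime.time(10, 30)
#     Timestamp(2019, 4, 25, 10, 30)  # -> datetime.datetime(2019, 4, 25, 10, 30)
#     TimestampFromTicks(0)           # local time at the Unix epoch
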
1781
1782# Additional type helpers for PyGreSQL:
1783
1784def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
1785    """Construct an object holding a time interval value."""
1786    return timedelta(days, hours=hours, minutes=minutes, seconds=seconds,
1787        microseconds=microseconds)
1788
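# For example, Interval(1, hours=2, minutes=30) simply returns
# timedelta(days=1, hours=2, minutes=30), i.e. an interval of 26.5 hours.
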
1789
1790Uuid = Uuid  # Construct an object holding a UUID value
1791
1792
1793class Hstore(dict):
1794    """Wrapper class for marking hstore values."""
1795
1796    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
1797    _re_escape = regex(r'(["\\])')
1798
1799    @classmethod
1800    def _quote(cls, s):
1801        if s is None:
1802            return 'NULL'
1803        if not s:
1804            return '""'
1805        quote = cls._re_quote.search(s)
1806        s = cls._re_escape.sub(r'\\\1', s)
1807        if quote:
1808            s = '"%s"' % (s,)
1809        return s
1810
1811    def __str__(self):
1812        q = self._quote
1813        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
1814
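# A short sketch of how Hstore values are rendered; keys and values follow the
# underlying dict, so the output order may vary:
#
#     str(Hstore({'label': 'a b', 'note': None}))
#     # -> 'label=>"a b",note=>NULL'
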
1815
1816class Json:
1817    """Construct a wrapper for holding an object serializable to JSON."""
1818
1819    def __init__(self, obj, encode=None):
1820        self.obj = obj
1821        self.encode = encode or jsonencode
1822
1823    def __str__(self):
1824        obj = self.obj
1825        if isinstance(obj, basestring):
1826            return obj
1827        return self.encode(obj)
1828
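# A brief sketch of the Json wrapper; str() yields the JSON-serialized text:
#
#     str(Json({'price': 1.5}))
#     # -> '{"price": 1.5}'
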
1829
1830class Literal:
1831    """Construct a wrapper for holding a literal SQL string."""
1832
1833    def __init__(self, sql):
1834        self.sql = sql
1835
1836    def __str__(self):
1837        return self.sql
1838
1839    __pg_repr__ = __str__
1840
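# Literal can be used to splice a trusted SQL fragment (such as an identifier)
# into a parameterized query without quoting; a cautious sketch (the table
# name is an assumption and must never come from untrusted input):
#
#     cur.execute("select * from %(table)s limit 1",
#         {'table': Literal('fruits')})
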
1841
1842# If run as script, print some information:
1843
1844if __name__ == '__main__':
1845    print('PyGreSQL version', version)
1846    print()
1847    print(__doc__)