source: trunk/pgdb.py

Last change on this file was 1027, checked in by cito, 3 weeks ago

Consistent spelling of PyGreSQL

  • Property svn:keywords set to Id
File size: 60.9 KB
Line 
1#!/usr/bin/python
2#
3# $Id: pgdb.py 1027 2019-10-05 15:30:59Z cito $
4#
5# PyGreSQL - a Python interface for the PostgreSQL database.
6#
7# This file contains the DB-API 2 compatible pgdb module.
8#
9# Copyright (c) 2019 by the PyGreSQL Development Team
10#
11# Please see the LICENSE.TXT file for specific restrictions.
12
13"""pgdb - DB-API 2.0 compliant module for PyGreSQL.
14
15(c) 1999, Pascal Andre <andre@via.ecp.fr>.
16See package documentation for further information on copyright.
17
18Inline documentation is sparse.
19See DB-API 2.0 specification for usage information:
20http://www.python.org/peps/pep-0249.html
21
22Basic usage:
23
24    pgdb.connect(connect_string) # open a connection
25    # connect_string = 'host:database:user:password:opt'
26    # All parts are optional. You may also pass host through
27    # password as keyword arguments. To pass a port,
28    # pass it in the host keyword parameter:
29    connection = pgdb.connect(host='localhost:5432')
30
31    cursor = connection.cursor() # open a cursor
32
33    cursor.execute(query[, params])
34    # Execute a query, binding params (a dictionary) if they are
35    # passed. The binding syntax is the same as the % operator
36    # for dictionaries, and no quoting is done.
37
38    cursor.executemany(query, list of params)
39    # Execute a query many times, binding each param dictionary
40    # from the list.
41
42    cursor.fetchone() # fetch one row, [value, value, ...]
43
44    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
45
46    cursor.fetchmany([size])
47    # returns size or cursor.arraysize number of rows,
48    # [[value, value, ...], ...] from result set.
49    # Default cursor.arraysize is 1.
50
51    cursor.description # returns information about the columns
52    #   [(column_name, type_name, display_size,
53    #           internal_size, precision, scale, null_ok), ...]
54    # Note that display_size, precision, scale and null_ok
55    # are not implemented.
56
57    cursor.rowcount # number of rows available in the result set
58    # Available after a call to execute.
59
60    connection.commit() # commit transaction
61
62    connection.rollback() # or rollback transaction
63
64    cursor.close() # close the cursor
65
66    connection.close() # close the connection
67"""
68
69from __future__ import print_function, division
70
71from _pg import *
72
73__version__ = version
74
75from datetime import date, time, datetime, timedelta, tzinfo
76from time import localtime
77from decimal import Decimal
78from uuid import UUID as Uuid
79from math import isnan, isinf
80try:
81    from collections.abc import Iterable
82except ImportError:  # Python < 3.3
83    from collections import Iterable
84from collections import namedtuple
85from keyword import iskeyword
86from functools import partial
87from re import compile as regex
88from json import loads as jsondecode, dumps as jsonencode
89
# Python 2/3 compatibility: make the Python 2 builtin names
# long, unicode and basestring available as aliases under Python 3.

try:  # noinspection PyUnresolvedReferences
    long
except NameError:  # Python >= 3.0
    long = int

try:  # noinspection PyUnresolvedReferences
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:  # noinspection PyUnresolvedReferences
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)
104
try:
    from functools import lru_cache
except ImportError:  # Python < 3.2
    from functools import update_wrapper
    try:
        from _thread import RLock
    except ImportError:
        class RLock:  # for builds without threads
            def __enter__(self): pass

            def __exit__(self, exctype, excinst, exctb): pass

    def lru_cache(maxsize=128):
        """Simplified functools.lru_cache decorator for one argument."""

        def decorator(function):
            sentinel = object()  # unique marker for cache misses
            cache = {}
            get = cache.get
            lock = RLock()
            # circular doubly linked list of links [PREV, NEXT, KEY, RESULT];
            # the root element sits between the newest and the oldest link
            root = []
            # mutable holder for the current root and the "cache full" flag
            root_full = [root, False]
            root[:] = [root, root, None, None]

            if maxsize == 0:

                def wrapper(arg):
                    # caching disabled: always call through
                    res = function(arg)
                    return res

            elif maxsize is None:

                def wrapper(arg):
                    # unbounded cache: no eviction bookkeeping needed
                    res = get(arg, sentinel)
                    if res is not sentinel:
                        return res
                    res = function(arg)
                    cache[arg] = res
                    return res

            else:

                def wrapper(arg):
                    # bounded cache with least-recently-used eviction
                    with lock:
                        link = get(arg)
                        if link is not None:
                            # cache hit: unlink and move to the front
                            root = root_full[0]
                            prev, next, _arg, res = link
                            prev[1] = next
                            next[0] = prev
                            last = root[0]
                            last[1] = root[0] = link
                            link[0] = last
                            link[1] = root
                            return res
                    res = function(arg)
                    with lock:
                        root, full = root_full
                        if arg in cache:
                            # another thread computed this entry meanwhile
                            pass
                        elif full:
                            # reuse the old root for the new entry and make
                            # the oldest link the new root, evicting its entry
                            oldroot = root
                            oldroot[2] = arg
                            oldroot[3] = res
                            root = root_full[0] = oldroot[1]
                            oldarg = root[2]
                            oldres = root[3]  # keep reference
                            root[2] = root[3] = None
                            del cache[oldarg]
                            cache[arg] = oldroot
                        else:
                            # insert a fresh link at the front of the list
                            last = root[0]
                            link = [last, root, arg, res]
                            last[1] = root[0] = cache[arg] = link
                            if len(cache) >= maxsize:
                                root_full[1] = True
                    return res

            wrapper.__wrapped__ = function
            return update_wrapper(wrapper, function)

        return decorator
187
188
### Module Constants

# compliant with DB API 2.0
apilevel = '2.0'

# module may be shared, but not connections
threadsafety = 1

# this module uses extended python format codes
paramstyle = 'pyformat'

# shortcut methods have been excluded from DB API 2 and
# are not recommended by the DB SIG, but they can be handy
shortcutmethods = 1
203
204
205### Internal Type Handling
206
# Determine the argument names of a callable in a version independent way.
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the list of argument names of the given function."""
        params = signature(func).parameters
        return [name for name in params]
218
try:
    from datetime import timezone
except ImportError:  # Python < 3.2

    class timezone(tzinfo):
        """Simple timezone implementation."""

        def __init__(self, offset, name=None):
            self.offset = offset  # UTC offset as a timedelta
            if not name:
                # derive a name of the form 'UTC+HH:MM' from the offset
                minutes = self.offset.days * 1440 + self.offset.seconds // 60
                if minutes < 0:
                    hours, minutes = divmod(-minutes, 60)
                    hours = -hours
                else:
                    hours, minutes = divmod(minutes, 60)
                name = 'UTC%+03d:%02d' % (hours, minutes)
            self.name = name

        def utcoffset(self, dt):
            return self.offset

        def tzname(self, dt):
            return self.name

        def dst(self, dt):
            # daylight saving time is not handled by this fallback
            return None

    timezone.utc = timezone(timedelta(0), 'UTC')

    _has_timezone = False
else:
    _has_timezone = True
252
253# time zones used in Postgres timestamptz output
254_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
255    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
256    UCT='+0000', UTC='+0000', WET='+0000')
257
258
259def _timezone_as_offset(tz):
260    if tz.startswith(('+', '-')):
261        if len(tz) < 5:
262            return tz + '00'
263        return tz.replace(':', '')
264    return _timezones.get(tz, '+0000')
265
266
def _get_timezone(tz):
    """Create a timezone object from a Postgres timezone suffix."""
    offset = _timezone_as_offset(tz)
    minutes = int(offset[1:3]) * 60 + int(offset[3:5])
    if offset[0] == '-':
        minutes = -minutes
    # name the timezone after its numeric offset
    return timezone(timedelta(minutes=minutes), offset)
273
274
def decimal_type(decimal_type=None):
    """Get or set global type to be used for decimal values.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    global Decimal
    if decimal_type is None:
        # getter: just report the currently configured type
        return Decimal
    # setter: remember the type and register it as the 'numeric' typecast
    Decimal = decimal_type
    set_typecast('numeric', decimal_type)
    return Decimal
286
287
def cast_bool(value):
    """Cast a boolean value in database format to a Python bool."""
    # an empty string yields None (NULLs are filtered out before casting)
    if not value:
        return None
    return value[0] in ('t', 'T')
292
293
def cast_money(value):
    """Cast a money value in database format to a Decimal."""
    if not value:
        return None
    # an opening parenthesis marks a negative amount in some lc_monetary
    # locales, so turn it into a minus sign
    value = value.replace('(', '-')
    digits = [c for c in value if c.isdigit() or c in '.-']
    return Decimal(''.join(digits))
299
300
def cast_int2vector(value):
    """Cast an int2vector value to a list of ints."""
    return list(map(int, value.split()))
304
305
def cast_date(value, connection):
    """Cast a date value using the date format of the connection."""
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value == '-infinity':
        return date.min
    if value == 'infinity':
        return date.max
    parts = value.split()
    if parts[-1] == 'BC':
        # dates before the common era cannot be represented, clamp to minimum
        return date.min
    first = parts[0]
    if len(first) > 10:
        # years with more than four digits are out of range, clamp to maximum
        return date.max
    return datetime.strptime(first, connection.date_format()).date()
324
325
def cast_time(value):
    """Cast a time value to a datetime.time object."""
    if len(value) > 8:
        # longer than 'HH:MM:SS' means fractional seconds are present
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    return datetime.strptime(value, fmt).time()
330
331
# regular expression splitting a value into the time and the timezone part
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value to a timezone-aware time object."""
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    if len(value) > 8:
        # longer than 'HH:MM:SS' means fractional seconds are present
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    if not _has_timezone:
        # Python < 3.2: attach our own timezone implementation afterwards
        return datetime.strptime(value, fmt).timetz().replace(
            tzinfo=_get_timezone(tz))
    # let strptime parse the numeric offset directly via %z
    return datetime.strptime(
        value + _timezone_as_offset(tz), fmt + '%z').timetz()
349
350
def cast_timestamp(value, connection):
    """Cast a timestamp value using the date format of the connection."""
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        # timestamps before the common era are clamped to the minimum
        return datetime.min
    fmt = connection.date_format()
    if fmt.endswith('-%Y') and len(parts) > 2:
        # date style with textual month and trailing year
        parts = parts[1:5]
        if len(parts[3]) > 4:
            # years with more than four digits are out of range
            return datetime.max
        if fmt.startswith('%d'):
            day_month = '%d %b'
        else:
            day_month = '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_month, time_fmt, '%Y']
    else:
        if len(parts[0]) > 10:
            # date part longer than 'YYYY-MM-DD' is out of range
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [fmt, time_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(formats))
372
373
def cast_timestamptz(value, connection):
    """Cast a timestamptz value using the date format of the connection.

    Returns a timezone-aware datetime; out-of-range values are clamped
    to datetime.min/datetime.max.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    value = value.split()
    if value[-1] == 'BC':
        # timestamps before the common era are clamped to the minimum
        return datetime.min
    fmt = connection.date_format()
    if fmt.endswith('-%Y') and len(value) > 2:
        # date style with textual month and trailing year
        value = value[1:]
        if len(value[3]) > 4:
            # years with more than four digits are out of range
            return datetime.max
        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
        # the last token is the timezone suffix
        value, tz = value[:-1], value[-1]
    else:
        if fmt.startswith('%Y-'):
            # ISO style: the offset is glued to the time token
            tz = _re_timezone.match(value[1])
            if tz:
                value[1], tz = tz.groups()
            else:
                tz = '+0000'
        else:
            # otherwise the timezone is a separate trailing token
            value, tz = value[:-1], value[-1]
        if len(value[0]) > 10:
            # date part longer than 'YYYY-MM-DD' is out of range
            return datetime.max
        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
    if _has_timezone:
        # let strptime parse the numeric offset directly via %z
        value.append(_timezone_as_offset(tz))
        fmt.append('%z')
        return datetime.strptime(' '.join(value), ' '.join(fmt))
    # Python < 3.2: attach our own timezone implementation afterwards
    return datetime.strptime(' '.join(value), ' '.join(fmt)).replace(
        tzinfo=_get_timezone(tz))
409
410
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _scale_fraction(m):
    """Scale the fraction-of-seconds digits in m[6] to microseconds.

    The digits after the decimal point must be scaled by their position:
    '5' means half a second, i.e. 500000 microseconds, not 5 microseconds.
    Digits beyond microsecond precision are truncated.
    """
    m[6] = m[6].ljust(6, '0')[:6]


def cast_interval(value):
    """Cast an interval value to a datetime.timedelta.

    Years are converted using 365 days and months using 30 days,
    since a timedelta cannot represent them exactly.

    Raises ValueError if the value cannot be parsed.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        secs_ago = m.pop(5) == '-'  # sign of the seconds part
        _scale_fraction(m)
        m = [int(d) for d in m]
        years, mons, days, hours, mins, secs, usecs = m
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            secs_ago = m.pop(5) == '-'  # sign of the seconds part
            _scale_fraction(m)
            # a trailing 'ago' negates the whole interval
            m = [-int(d) for d in m] if ago else [int(d) for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                hours_ago = m.pop(3) == '-'  # sign of the time part
                _scale_fraction(m)
                m = [int(d) for d in m]
                years, mons, days, hours, mins, secs, usecs = m
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = [d or '0' for d in m.groups()]
                    years_ago = m.pop(0) == '-'  # sign of years-months
                    hours_ago = m.pop(3) == '-'  # sign of the time part
                    _scale_fraction(m)
                    m = [int(d) for d in m]
                    years, mons, days, hours, mins, secs, usecs = m
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    # approximate years and months, as timedelta cannot represent them
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
496
497
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
        'float4': float, 'float8': float,
        'numeric': Decimal, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': Uuid,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in local connection specific instances

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # create array cast (array type names carry an underscore prefix)
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                # store only if base type exists
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument."""
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            # not introspectable (e.g. a builtin), assume no connection needed
            return False
        else:
            # ignore the first argument, which receives the value itself
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type."""
        # the subscript triggers __missing__ for uncached types
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        Passing None as cast removes the typecast for the given type(s).
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                # remove the cast for the type and its array type
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                # invalidate the cached array cast for this type
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.
        """
        defaults = self.defaults
        if typ is None:
            self.clear()
            self.update(defaults)
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                cast = defaults.get(t)
                if cast:
                    self[t] = self._add_connection(cast)
                    # also reset the corresponding array type
                    t = '_%s' % t
                    cast = defaults.get(t)
                    if cast:
                        self[t] = self._add_connection(cast)
                    else:
                        self.pop(t, None)
                else:
                    # no default exists, so remove any custom casts
                    self.pop(t, None)
                    self.pop('_%s' % t, None)

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        cast_array = self['anyarray']
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        cast_record = self['record']
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
621
622
_typecasts = Typecasts()  # this is the global typecast dictionary


def get_typecast(typ):
    """Get the global typecast function for the given database type(s).

    Returns None when no cast function is registered for the type.
    """
    return _typecasts.get(typ)
629
630
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    The cast must be a callable, or None to remove the typecast.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    _typecasts.set(typ, cast)
638
639
def reset_typecast(typ=None):
    """Reset the global typecasts for the given type(s) to their default.

    When no type is specified, all typecasts will be reset.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    _typecasts.reset(typ)
649
650
class LocalTypecasts(Typecasts):
    """Map typecasts, including local composite types, to cast functions."""

    # fall back to the global typecast dictionary for defaults
    defaults = _typecasts

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached."""
        if typ.startswith('_'):
            # array type: build a cast around the base type's cast
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                self[typ] = cast
        else:
            cast = self.defaults.get(typ)
            if cast:
                # store the global default for faster access
                cast = self._add_connection(cast)
                self[typ] = cast
            else:
                # possibly a composite type local to the connected database
                fields = self.get_fields(typ)
                if fields:
                    casts = [self[field.type] for field in fields]
                    fields = [field.name for field in fields]
                    cast = self.create_record_cast(typ, fields, casts)
                    self[typ] = cast
        return cast

    def get_fields(self, typ):
        """Return the fields for the given record type.

        This method will be replaced with a method that looks up the fields
        using the type cache of the connection.
        """
        return []
686
687
class TypeCode(str):
    """Class representing the type_code used by the DB-API 2.0.

    TypeCode objects are strings equal to the PostgreSQL type name,
    but carry some additional information.
    """

    @classmethod
    def create(cls, oid, name, len, type, category, delim, relid):
        """Create a type code for a PostgreSQL data type."""
        code = cls(name)
        # attach the additional type information as attributes
        for attr, value in (('oid', oid), ('len', len), ('type', type),
                ('category', category), ('delim', delim), ('relid', relid)):
            setattr(code, attr, value)
        return code

FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
708
709
class TypeCache(dict):
    """Cache for database types.

    This cache maps type OIDs and names to TypeCode strings containing
    important information on the associated database type.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()  # source object used for catalog queries
        self._typecasts = LocalTypecasts()
        # let the typecast dictionary look up composite fields via this cache
        self._typecasts.get_fields = self.get_fields
        self._typecasts.connection = cnx
        if cnx.server_version < 80400:
            # older remote databases (not officially supported)
            self._query_pg_type = ("SELECT oid, typname,"
                " typlen, typtype, null as typcategory, typdelim, typrelid"
                " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) %s")
        else:
            self._query_pg_type = ("SELECT oid, typname,"
                " typlen, typtype, typcategory, typdelim, typrelid"
                " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) %s")

    def __missing__(self, key):
        """Get the type info from the database if it is not cached."""
        if isinstance(key, int):
            oid = key
        else:
            # look up by name: quote plain names so that mixed-case
            # names resolve properly through ::regtype
            if '.' not in key and '"' not in key:
                key = '"%s"' % (key,)
            oid = "'%s'::regtype" % (self._escape_string(key),)
        try:
            self._src.execute(self._query_pg_type % (oid,))
        except ProgrammingError:
            res = None
        else:
            res = self._src.fetch(1)
        if not res:
            raise KeyError('Type %s could not be found' % (key,))
        res = res[0]
        type_code = TypeCode.create(int(res[0]), res[1],
            int(res[2]), res[3], res[4], res[5], int(res[6]))
        # cache the type code under both its OID and its name
        self[type_code.oid] = self[str(type_code)] = type_code
        return type_code

    def get(self, key, default=None):
        """Get the type even if it is not cached."""
        try:
            return self[key]
        except KeyError:
            return default

    def get_fields(self, typ):
        """Get the names and types of the fields of composite types."""
        if not isinstance(typ, TypeCode):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None  # this type is not composite
        self._src.execute("SELECT attname, atttypid"
            " FROM pg_catalog.pg_attribute"
            " WHERE attrelid OPERATOR(pg_catalog.=) %s"
            " AND attnum OPERATOR(pg_catalog.>) 0"
            " AND NOT attisdropped ORDER BY attnum" % (typ.relid,))
        return [FieldInfo(name, self.get(int(oid)))
            for name, oid in self._src.fetch(-1)]

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type."""
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        cast = self.get_typecast(typ)
        if not cast or cast is str:
            # no typecast is necessary
            return value
        return cast(value)
802
803
804class _quotedict(dict):
805    """Dictionary with auto quoting of its items.
806
807    The quote attribute must be set to the desired quote function.
808    """
809
810    def __getitem__(self, key):
811        return self.quote(super(_quotedict, self).__getitem__(key))
812
813
814### Error Messages
815
def _db_error(msg, cls=DatabaseError):
    """Return a DatabaseError instance with an empty sqlstate attribute."""
    exc = cls(msg)
    # errors raised by the module itself carry no SQLSTATE from the server
    exc.sqlstate = None
    return exc
821
822
def _op_error(msg):
    """Return an OperationalError with an empty sqlstate attribute."""
    return _db_error(msg, cls=OperationalError)
826
827
828### Row Tuples
829
830_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$')
831
832# The result rows for database operations are returned as named tuples
833# by default. Since creating namedtuple classes is a somewhat expensive
834# operation, we cache up to 1024 of these classes by default.
835
836@lru_cache(maxsize=1024)
837def _row_factory(names):
838    """Get a namedtuple factory for row results with the given names."""
839    try:
840        try:
841            return namedtuple('Row', names, rename=True)._make
842        except TypeError:  # Python 2.6 and 3.0 do not support rename
843            names = [v if _re_fieldname.match(v) and not iskeyword(v)
844                        else 'column_%d' % (n,)
845                     for n, v in enumerate(names)]
846            return namedtuple('Row', names)._make
847    except ValueError:  # there is still a problem with the field names
848        names = ['column_%d' % (n,) for n in range(len(names))]
849        return namedtuple('Row', names)._make
850
851
def set_row_factory_size(maxsize):
    """Change the size of the namedtuple factory cache.

    If maxsize is set to None, the cache can grow without bound.
    """
    global _row_factory
    # rewrap the undecorated function with a cache of the new size
    _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__)
859
860
861### Cursor Object
862
863class Cursor(object):
864    """Cursor object."""
865
    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx  # the underlying connection of the _pg module
        self.type_cache = dbcnx.type_cache
        self._src = self._cnx.source()  # source object for executing queries
        # the official attribute for describing the result columns
        self._description = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a subclass overrode row_factory, so disable build_row_factory
            # (both are presumably class attributes defined further down)
            self.build_row_factory = None
        self.rowcount = -1  # -1 until a statement has been executed
        self.arraysize = 1  # default number of rows for fetchmany()
        self.lastrowid = None
882
    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        # returning self makes the cursor its own iterator
        return self
886
    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        # the cursor itself serves as the context manager
        return self
890
    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object."""
        # close the cursor; exceptions are not suppressed (returns None)
        self.close()
894
895    def _quote(self, value):
896        """Quote value depending on its type."""
897        if value is None:
898            return 'NULL'
899        if isinstance(value, (Hstore, Json)):
900            value = str(value)
901        if isinstance(value, basestring):
902            if isinstance(value, Binary):
903                value = self._cnx.escape_bytea(value)
904                if bytes is not str:  # Python >= 3.0
905                    value = value.decode('ascii')
906            else:
907                value = self._cnx.escape_string(value)
908            return "'%s'" % (value,)
909        if isinstance(value, float):
910            if isinf(value):
911                return "'-Infinity'" if value < 0 else "'Infinity'"
912            if isnan(value):
913                return "'NaN'"
914            return value
915        if isinstance(value, (int, long, Decimal, Literal)):
916            return value
917        if isinstance(value, datetime):
918            if value.tzinfo:
919                return "'%s'::timestamptz" % (value,)
920            return "'%s'::timestamp" % (value,)
921        if isinstance(value, date):
922            return "'%s'::date" % (value,)
923        if isinstance(value, time):
924            if value.tzinfo:
925                return "'%s'::timetz" % (value,)
926            return "'%s'::time" % value
927        if isinstance(value, timedelta):
928            return "'%s'::interval" % (value,)
929        if isinstance(value, Uuid):
930            return "'%s'::uuid" % (value,)
931        if isinstance(value, list):
932            # Quote value as an ARRAY constructor. This is better than using
933            # an array literal because it carries the information that this is
934            # an array and not a string.  One issue with this syntax is that
935            # you need to add an explicit typecast when passing empty arrays.
936            # The ARRAY keyword is actually only necessary at the top level.
937            if not value:  # exception for empty array
938                return "'{}'"
939            q = self._quote
940            try:
941                return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),)
942            except UnicodeEncodeError:  # Python 2 with non-ascii values
943                return u'ARRAY[%s]' % (','.join(unicode(q(v)) for v in value),)
944        if isinstance(value, tuple):
945            # Quote as a ROW constructor.  This is better than using a record
946            # literal because it carries the information that this is a record
947            # and not a string.  We don't use the keyword ROW in order to make
948            # this usable with the IN syntax as well.  It is only necessary
949            # when the records has a single column which is not really useful.
950            q = self._quote
951            try:
952                return '(%s)' % (','.join(str(q(v)) for v in value),)
953            except UnicodeEncodeError:  # Python 2 with non-ascii values
954                return u'(%s)' % (','.join(unicode(q(v)) for v in value),)
955        try:
956            value = value.__pg_repr__()
957        except AttributeError:
958            raise InterfaceError(
959                'Do not know how to adapt type %s' % (type(value),))
960        if isinstance(value, (tuple, list)):
961            value = self._quote(value)
962        return value
963
964    def _quoteparams(self, string, parameters):
965        """Quote parameters.
966
967        This function works for both mappings and sequences.
968
969        The function should be used even when there are no parameters,
970        so that we have a consistent behavior regarding percent signs.
971        """
972        if not parameters:
973            try:
974                return string % ()  # unescape literal quotes if possible
975            except (TypeError, ValueError):
976                return string  # silently accept unescaped quotes
977        if isinstance(parameters, dict):
978            parameters = _quotedict(parameters)
979            parameters.quote = self._quote
980        else:
981            parameters = tuple(map(self._quote, parameters))
982        return string % parameters
983
984    def _make_description(self, info):
985        """Make the description tuple for the given field info."""
986        name, typ, size, mod = info[1:]
987        type_code = self.type_cache[typ]
988        if mod > 0:
989            mod -= 4
990        if type_code == 'numeric':
991            precision, scale = mod >> 16, mod & 0xffff
992            size = precision
993        else:
994            if not size:
995                size = type_code.size
996            if size == -1:
997                size = mod
998            precision = scale = None
999        return CursorDescription(name, type_code,
1000            None, size, precision, scale, None)
1001
1002    @property
1003    def description(self):
1004        """Read-only attribute describing the result columns."""
1005        descr = self._description
1006        if self._description is True:
1007            make = self._make_description
1008            descr = [make(info) for info in self._src.listinfo()]
1009            self._description = descr
1010        return descr
1011
1012    @property
1013    def colnames(self):
1014        """Unofficial convenience method for getting the column names."""
1015        return [d[0] for d in self.description]
1016
1017    @property
1018    def coltypes(self):
1019        """Unofficial convenience method for getting the column types."""
1020        return [d[1] for d in self.description]
1021
1022    def close(self):
1023        """Close the cursor object."""
1024        self._src.close()
1025
1026    def execute(self, operation, parameters=None):
1027        """Prepare and execute a database operation (query or command)."""
1028        # The parameters may also be specified as list of tuples to e.g.
1029        # insert multiple rows in a single operation, but this kind of
1030        # usage is deprecated.  We make several plausibility checks because
1031        # tuples can also be passed with the meaning of ROW constructors.
1032        if (parameters and isinstance(parameters, list)
1033                and len(parameters) > 1
1034                and all(isinstance(p, tuple) for p in parameters)
1035                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
1036            return self.executemany(operation, parameters)
1037        else:
1038            # not a list of tuples
1039            return self.executemany(operation, [parameters])
1040
1041    def executemany(self, operation, seq_of_parameters):
1042        """Prepare operation and execute it against a parameter sequence."""
1043        if not seq_of_parameters:
1044            # don't do anything without parameters
1045            return
1046        self._description = None
1047        self.rowcount = -1
1048        # first try to execute all queries
1049        rowcount = 0
1050        sql = "BEGIN"
1051        try:
1052            if not self._dbcnx._tnx and not self._dbcnx.autocommit:
1053                try:
1054                    self._src.execute(sql)
1055                except DatabaseError:
1056                    raise  # database provides error message
1057                except Exception:
1058                    raise _op_error("Can't start transaction")
1059                else:
1060                    self._dbcnx._tnx = True
1061            for parameters in seq_of_parameters:
1062                sql = operation
1063                sql = self._quoteparams(sql, parameters)
1064                rows = self._src.execute(sql)
1065                if rows:  # true if not DML
1066                    rowcount += rows
1067                else:
1068                    self.rowcount = -1
1069        except DatabaseError:
1070            raise  # database provides error message
1071        except Error as err:
1072            raise _db_error(
1073                "Error in '%s': '%s' " % (sql, err), InterfaceError)
1074        except Exception as err:
1075            raise _op_error("Internal error in '%s': %s" % (sql, err))
1076        # then initialize result raw count and description
1077        if self._src.resulttype == RESULT_DQL:
1078            self._description = True  # fetch on demand
1079            self.rowcount = self._src.ntuples
1080            self.lastrowid = None
1081            if self.build_row_factory:
1082                self.row_factory = self.build_row_factory()
1083        else:
1084            self.rowcount = rowcount
1085            self.lastrowid = self._src.oidstatus()
1086        # return the cursor object, so you can write statements such as
1087        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
1088        return self
1089
1090    def fetchone(self):
1091        """Fetch the next row of a query result set."""
1092        res = self.fetchmany(1, False)
1093        try:
1094            return res[0]
1095        except IndexError:
1096            return None
1097
1098    def fetchall(self):
1099        """Fetch all (remaining) rows of a query result."""
1100        return self.fetchmany(-1, False)
1101
1102    def fetchmany(self, size=None, keep=False):
1103        """Fetch the next set of rows of a query result.
1104
1105        The number of rows to fetch per call is specified by the
1106        size parameter. If it is not given, the cursor's arraysize
1107        determines the number of rows to be fetched. If you set
1108        the keep parameter to true, this is kept as new arraysize.
1109        """
1110        if size is None:
1111            size = self.arraysize
1112        if keep:
1113            self.arraysize = size
1114        try:
1115            result = self._src.fetch(size)
1116        except DatabaseError:
1117            raise
1118        except Error as err:
1119            raise _db_error(str(err))
1120        typecast = self.type_cache.typecast
1121        return [self.row_factory([typecast(value, typ)
1122            for typ, value in zip(self.coltypes, row)]) for row in result]
1123
1124    def callproc(self, procname, parameters=None):
1125        """Call a stored database procedure with the given name.
1126
1127        The sequence of parameters must contain one entry for each input
1128        argument that the procedure expects. The result of the call is the
1129        same as this input sequence; replacement of output and input/output
1130        parameters in the return value is currently not supported.
1131
1132        The procedure may also provide a result set as output. These can be
1133        requested through the standard fetch methods of the cursor.
1134        """
1135        n = parameters and len(parameters) or 0
1136        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
1137        self.execute(query, parameters)
1138        return parameters
1139
1140    def copy_from(self, stream, table,
1141            format=None, sep=None, null=None, size=None, columns=None):
1142        """Copy data from an input stream to the specified table.
1143
1144        The input stream can be a file-like object with a read() method or
1145        it can also be an iterable returning a row or multiple rows of input
1146        on each iteration.
1147
1148        The format must be text, csv or binary. The sep option sets the
1149        column separator (delimiter) used in the non binary formats.
1150        The null option sets the textual representation of NULL in the input.
1151
1152        The size option sets the size of the buffer used when reading data
1153        from file-like objects.
1154
1155        The copy operation can be restricted to a subset of columns. If no
1156        columns are specified, all of them will be copied.
1157        """
1158        binary_format = format == 'binary'
1159        try:
1160            read = stream.read
1161        except AttributeError:
1162            if size:
1163                raise ValueError("Size must only be set for file-like objects")
1164            if binary_format:
1165                input_type = bytes
1166                type_name = 'byte strings'
1167            else:
1168                input_type = basestring
1169                type_name = 'strings'
1170
1171            if isinstance(stream, basestring):
1172                if not isinstance(stream, input_type):
1173                    raise ValueError("The input must be %s" % (type_name,))
1174                if not binary_format:
1175                    if isinstance(stream, str):
1176                        if not stream.endswith('\n'):
1177                            stream += '\n'
1178                    else:
1179                        if not stream.endswith(b'\n'):
1180                            stream += b'\n'
1181
1182                def chunks():
1183                    yield stream
1184
1185            elif isinstance(stream, Iterable):
1186
1187                def chunks():
1188                    for chunk in stream:
1189                        if not isinstance(chunk, input_type):
1190                            raise ValueError(
1191                                "Input stream must consist of %s"
1192                                % (type_name,))
1193                        if isinstance(chunk, str):
1194                            if not chunk.endswith('\n'):
1195                                chunk += '\n'
1196                        else:
1197                            if not chunk.endswith(b'\n'):
1198                                chunk += b'\n'
1199                        yield chunk
1200
1201            else:
1202                raise TypeError("Need an input stream to copy from")
1203        else:
1204            if size is None:
1205                size = 8192
1206            elif not isinstance(size, int):
1207                raise TypeError("The size option must be an integer")
1208            if size > 0:
1209
1210                def chunks():
1211                    while True:
1212                        buffer = read(size)
1213                        yield buffer
1214                        if not buffer or len(buffer) < size:
1215                            break
1216
1217            else:
1218
1219                def chunks():
1220                    yield read()
1221
1222        if not table or not isinstance(table, basestring):
1223            raise TypeError("Need a table to copy to")
1224        if table.lower().startswith('select'):
1225            raise ValueError("Must specify a table, not a query")
1226        else:
1227            table = '"%s"' % (table,)
1228        operation = ['copy %s' % (table,)]
1229        options = []
1230        params = []
1231        if format is not None:
1232            if not isinstance(format, basestring):
1233                raise TypeError("The format option must be be a string")
1234            if format not in ('text', 'csv', 'binary'):
1235                raise ValueError("Invalid format")
1236            options.append('format %s' % (format,))
1237        if sep is not None:
1238            if not isinstance(sep, basestring):
1239                raise TypeError("The sep option must be a string")
1240            if format == 'binary':
1241                raise ValueError(
1242                    "The sep option is not allowed with binary format")
1243            if len(sep) != 1:
1244                raise ValueError(
1245                    "The sep option must be a single one-byte character")
1246            options.append('delimiter %s')
1247            params.append(sep)
1248        if null is not None:
1249            if not isinstance(null, basestring):
1250                raise TypeError("The null option must be a string")
1251            options.append('null %s')
1252            params.append(null)
1253        if columns:
1254            if not isinstance(columns, basestring):
1255                columns = ','.join('"%s"' % (col,) for col in columns)
1256            operation.append('(%s)' % (columns,))
1257        operation.append("from stdin")
1258        if options:
1259            operation.append('(%s)' % (','.join(options),))
1260        operation = ' '.join(operation)
1261
1262        putdata = self._src.putdata
1263        self.execute(operation, params)
1264
1265        try:
1266            for chunk in chunks():
1267                putdata(chunk)
1268        except BaseException as error:
1269            self.rowcount = -1
1270            # the following call will re-raise the error
1271            putdata(error)
1272        else:
1273            self.rowcount = putdata(None)
1274
1275        # return the cursor object, so you can chain operations
1276        return self
1277
1278    def copy_to(self, stream, table,
1279            format=None, sep=None, null=None, decode=None, columns=None):
1280        """Copy data from the specified table to an output stream.
1281
1282        The output stream can be a file-like object with a write() method or
1283        it can also be None, in which case the method will return a generator
1284        yielding a row on each iteration.
1285
1286        Output will be returned as byte strings unless you set decode to true.
1287
1288        Note that you can also use a select query instead of the table name.
1289
1290        The format must be text, csv or binary. The sep option sets the
1291        column separator (delimiter) used in the non binary formats.
1292        The null option sets the textual representation of NULL in the output.
1293
1294        The copy operation can be restricted to a subset of columns. If no
1295        columns are specified, all of them will be copied.
1296        """
1297        binary_format = format == 'binary'
1298        if stream is not None:
1299            try:
1300                write = stream.write
1301            except AttributeError:
1302                raise TypeError("Need an output stream to copy to")
1303        if not table or not isinstance(table, basestring):
1304            raise TypeError("Need a table to copy to")
1305        if table.lower().startswith('select'):
1306            if columns:
1307                raise ValueError("Columns must be specified in the query")
1308            table = '(%s)' % (table,)
1309        else:
1310            table = '"%s"' % (table,)
1311        operation = ['copy %s' % (table,)]
1312        options = []
1313        params = []
1314        if format is not None:
1315            if not isinstance(format, basestring):
1316                raise TypeError("The format option must be a string")
1317            if format not in ('text', 'csv', 'binary'):
1318                raise ValueError("Invalid format")
1319            options.append('format %s' % (format,))
1320        if sep is not None:
1321            if not isinstance(sep, basestring):
1322                raise TypeError("The sep option must be a string")
1323            if binary_format:
1324                raise ValueError(
1325                    "The sep option is not allowed with binary format")
1326            if len(sep) != 1:
1327                raise ValueError(
1328                    "The sep option must be a single one-byte character")
1329            options.append('delimiter %s')
1330            params.append(sep)
1331        if null is not None:
1332            if not isinstance(null, basestring):
1333                raise TypeError("The null option must be a string")
1334            options.append('null %s')
1335            params.append(null)
1336        if decode is None:
1337            if format == 'binary':
1338                decode = False
1339            else:
1340                decode = str is unicode
1341        else:
1342            if not isinstance(decode, (int, bool)):
1343                raise TypeError("The decode option must be a boolean")
1344            if decode and binary_format:
1345                raise ValueError(
1346                    "The decode option is not allowed with binary format")
1347        if columns:
1348            if not isinstance(columns, basestring):
1349                columns = ','.join('"%s"' % (col,) for col in columns)
1350            operation.append('(%s)' % (columns,))
1351
1352        operation.append("to stdout")
1353        if options:
1354            operation.append('(%s)' % (','.join(options),))
1355        operation = ' '.join(operation)
1356
1357        getdata = self._src.getdata
1358        self.execute(operation, params)
1359
1360        def copy():
1361            self.rowcount = 0
1362            while True:
1363                row = getdata(decode)
1364                if isinstance(row, int):
1365                    if self.rowcount != row:
1366                        self.rowcount = row
1367                    break
1368                self.rowcount += 1
1369                yield row
1370
1371        if stream is None:
1372            # no input stream, return the generator
1373            return copy()
1374
1375        # write the rows to the file-like input stream
1376        for row in copy():
1377            write(row)
1378
1379        # return the cursor object, so you can chain operations
1380        return self
1381
1382    def __next__(self):
1383        """Return the next row (support for the iteration protocol)."""
1384        res = self.fetchone()
1385        if res is None:
1386            raise StopIteration
1387        return res
1388
1389    # Note that since Python 2.6 the iterator protocol uses __next()__
1390    # instead of next(), we keep it only for backward compatibility of pgdb.
1391    next = __next__
1392
1393    @staticmethod
1394    def nextset():
1395        """Not supported."""
1396        raise NotSupportedError("The nextset() method is not supported")
1397
1398    @staticmethod
1399    def setinputsizes(sizes):
1400        """Not supported."""
1401        pass  # unsupported, but silently passed
1402
1403    @staticmethod
1404    def setoutputsize(size, column=0):
1405        """Not supported."""
1406        pass  # unsupported, but silently passed
1407
1408    @staticmethod
1409    def row_factory(row):
1410        """Process rows before they are returned.
1411
1412        You can overwrite this statically with a custom row factory, or
1413        you can build a row factory dynamically with build_row_factory().
1414
1415        For example, you can create a Cursor class that returns rows as
1416        Python dictionaries like this:
1417
1418            class DictCursor(pgdb.Cursor):
1419
1420                def row_factory(self, row):
1421                    return {desc[0]: value
1422                        for desc, value in zip(self.description, row)}
1423
1424            cur = DictCursor(con)  # get one DictCursor instance or
1425            con.cursor_type = DictCursor  # always use DictCursor instances
1426        """
1427        raise NotImplementedError
1428
1429    def build_row_factory(self):
1430        """Build a row factory based on the current description.
1431
1432        This implementation builds a row factory for creating named tuples.
1433        You can overwrite this method if you want to dynamically create
1434        different row factories whenever the column description changes.
1435        """
1436        names = self.colnames
1437        if names:
1438            return _row_factory(tuple(names))
1439
1440
# DB-API 2 description tuple for one result column (PEP 249 "description")
CursorDescription = namedtuple('CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
1444
1445
1446### Connection Objects
1447
class Connection(object):
    """Connection object."""

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object."""
        self._cnx = cnx  # connection
        self._tnx = False  # transaction state
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor
        self.autocommit = False
        # verify that the low-level connection is usable by opening a source
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("Invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.

        This also starts a transaction in autocommit mode.
        """
        if self.autocommit:
            try:
                self._cnx.source().execute("BEGIN")
            except DatabaseError:
                raise  # database provides error message
            except Exception:
                raise _op_error("Can't start transaction")
            else:
                self._tnx = True
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction.
        """
        # commit on clean exit, roll back if an exception was raised
        if et is None and ev is None and tb is None:
            self.commit()
        else:
            self.rollback()

    def close(self):
        """Close the connection object."""
        if self._cnx:
            if self._tnx:
                # best effort: abandon any pending transaction,
                # ignoring errors from an already broken connection
                try:
                    self.rollback()
                except DatabaseError:
                    pass
            self._cnx.close()
            self._cnx = None
        else:
            raise _op_error("Connection has been closed")

    @property
    def closed(self):
        """Check whether the connection has been closed or is broken."""
        # status raises TypeError on an already closed low-level connection
        try:
            return not self._cnx or self._cnx.status != 1
        except TypeError:
            return True

    def commit(self):
        """Commit any pending transaction to the database."""
        if self._cnx:
            if self._tnx:
                # reset the transaction state before executing COMMIT,
                # so the state is clean even if the command fails
                self._tnx = False
                try:
                    self._cnx.source().execute("COMMIT")
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("Can't commit transaction")
        else:
            raise _op_error("Connection has been closed")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        if self._cnx:
            if self._tnx:
                # reset the transaction state before executing ROLLBACK,
                # so the state is clean even if the command fails
                self._tnx = False
                try:
                    self._cnx.source().execute("ROLLBACK")
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("Can't rollback transaction")
        else:
            raise _op_error("Connection has been closed")

    def cursor(self):
        """Return a new cursor object using the connection."""
        if self._cnx:
            try:
                return self.cursor_type(self)
            except Exception:
                raise _op_error("Invalid connection")
        else:
            raise _op_error("Connection has been closed")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cursor = self.cursor()
            cursor.execute(operation, params)
            return cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cursor = self.cursor()
            cursor.executemany(operation, param_seq)
            return cursor
1575
1576
1577### Module Interface
1578
# keep a reference to the low-level connect function before shadowing it
_connect = connect

def connect(dsn=None,
        user=None, password=None,
        host=None, database=None, **kwargs):
    """Connect to a database.

    The DSN has the form 'host:database:user:password:opt'; all parts
    are optional and can be overridden by the keyword arguments.
    A port can be given as part of the host ('host:port').
    Additional keyword arguments are passed on in the connection string.
    """
    # first get params from DSN
    dbport = -1
    dbhost = ""
    dbname = ""
    dbuser = ""
    dbpasswd = ""
    dbopt = ""
    try:
        params = dsn.split(":")
        dbhost = params[0]
        dbname = params[1]
        dbuser = params[2]
        dbpasswd = params[3]
        dbopt = params[4]
    except (AttributeError, IndexError, TypeError):
        pass  # accept partial or missing DSNs

    # override if necessary
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbname = database
    if host is not None:
        try:
            params = host.split(":")
            dbhost = params[0]
            dbport = int(params[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # accept host without port

    # empty host is localhost
    if dbhost == "":
        dbhost = None
    if dbuser == "":
        dbuser = None

    # pass keyword arguments as connection info string
    if kwargs:
        kwargs = list(kwargs.items())
        if '=' in dbname:
            dbname = [dbname]
        else:
            kwargs.insert(0, ('dbname', dbname))
            dbname = []
        for kw, value in kwargs:
            value = str(value)
            if not value or ' ' in value:
                # quote the value; backslashes must be escaped before
                # the quotes, otherwise doubling the backslashes would
                # corrupt the quote escapes that were just inserted
                value = "'%s'" % (value.replace(
                    '\\', '\\\\').replace("'", "\\'"),)
            dbname.append('%s=%s' % (kw, value))
        dbname = ' '.join(dbname)

    # open the connection
    cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
1642
1643
1644### Types Handling
1645
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.
    """

    def __new__(cls, values):
        # accept a whitespace-separated string of type names as well
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            # ignore the leading underscore marking array types
            if other.startswith('_'):
                other = other[1:]
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Defining __eq__ implicitly sets __hash__ to None in Python 3,
    # which would make Type instances unhashable; explicitly keep the
    # frozenset hash to preserve the Python 2 behavior.
    __hash__ = frozenset.__hash__
1673
1674
class ArrayType:
    """Type class for PostgreSQL array types."""

    def __eq__(self, other):
        """Equal to array type names (leading underscore) and instances."""
        if isinstance(other, basestring):
            return other.startswith('_')
        return isinstance(other, ArrayType)

    def __ne__(self, other):
        """Inequality is simply the negation of equality."""
        return not self.__eq__(other)
1689
1690
class RecordType:
    """Type class for PostgreSQL record types.

    Equates to type codes of category 'c' (composite), to the type
    name 'record', and to other RecordType instances.
    """

    def __eq__(self, other):
        if isinstance(other, TypeCode):
            return other.type == 'c'
        if isinstance(other, basestring):
            return other == 'record'
        return isinstance(other, RecordType)

    def __ne__(self, other):
        # simply negate the equality check
        return not self.__eq__(other)
1709
1710
# Mandatory type objects defined by DB-API 2 specs:

# Each Type compares equal to any of the space-separated PostgreSQL
# type names it was constructed with (arrays included, since Type
# strips the leading underscore of array type names).
STRING = Type('char bpchar name text varchar')
BINARY = Type('bytea')
NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
DATETIME = Type('date time timetz timestamp timestamptz interval'
    ' abstime reltime')  # these are very old
ROWID = Type('oid')


# Additional type objects (more specific):

BOOL = Type('bool')
SMALLINT = Type('int2')
INTEGER = Type('int2 int4 int8 serial')
LONG = Type('int8')
FLOAT = Type('float4 float8')
NUMERIC = Type('numeric')
MONEY = Type('money')
DATE = Type('date')
TIME = Type('time timetz')
TIMESTAMP = Type('timestamp timestamptz')
INTERVAL = Type('interval')
UUID = Type('uuid')
HSTORE = Type('hstore')
JSON = Type('json jsonb')

# Type object for arrays (also equate to their base types):

ARRAY = ArrayType()

# Type object for records (encompassing all composite types):

RECORD = RecordType()
1745
1746
1747# Mandatory type helpers defined by DB-API 2 specs:
1748
def Date(year, month, day):
    """Construct an object holding a date value (DB-API 2 helper)."""
    return date(year=year, month=month, day=day)
1752
1753
def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Construct an object holding a time value (DB-API 2 helper)."""
    return time(hour=hour, minute=minute, second=second,
        microsecond=microsecond, tzinfo=tzinfo)
1757
1758
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Construct an object holding a time stamp value (DB-API 2 helper)."""
    return datetime(year=year, month=month, day=day, hour=hour,
        minute=minute, second=second, microsecond=microsecond, tzinfo=tzinfo)
1763
1764
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    Ticks are seconds since the epoch, interpreted in local time.
    """
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
1768
1769
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    Ticks are seconds since the epoch, interpreted in local time.
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
1773
1774
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    Ticks are seconds since the epoch, interpreted in local time.
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
1778
1779
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value."""
    # A plain bytes subclass is all that is needed to satisfy the
    # DB-API 2 Binary() constructor requirement; no behavior is added.
1782
1783
1784# Additional type helpers for PyGreSQL:
1785
def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
    """Construct an object holding a time interval value.

    This is a PyGreSQL extension to the DB-API 2 type helpers.
    """
    return timedelta(days=days, hours=hours, minutes=minutes,
        seconds=seconds, microseconds=microseconds)
1790
1791
# NOTE(review): re-export so pgdb offers a UUID constructor; Uuid is
# presumably imported from the standard uuid module earlier in this file.
Uuid = Uuid  # Construct an object holding a UUID value
1793
1794
class Hstore(dict):
    """Wrapper class for marking hstore values."""

    # tokens that must be double-quoted: the word NULL or anything
    # containing whitespace, commas or the key/value separator chars
    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s):
        """Quote and escape a single hstore key or value."""
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        # decide on quoting from the raw string, before escaping adds
        # backslashes that would otherwise not trigger quoting
        needs_quotes = cls._re_quote.search(s) is not None
        s = cls._re_escape.sub(r'\\\1', s)
        return '"%s"' % (s,) if needs_quotes else s

    def __str__(self):
        """Render the dictionary in hstore input syntax."""
        quote = self._quote
        return ','.join(
            '%s=>%s' % (quote(key), quote(value))
            for key, value in self.items())
1816
1817
class Json:
    """Construct a wrapper for holding an object serializable to JSON."""

    def __init__(self, obj, encode=None):
        self.obj = obj  # the wrapped Python object
        self.encode = encode or jsonencode  # serializer, defaults to json.dumps

    def __str__(self):
        obj = self.obj
        # strings are assumed to contain already-serialized JSON
        return obj if isinstance(obj, basestring) else self.encode(obj)
1830
1831
class Literal:
    """Construct a wrapper for holding a literal SQL string."""

    def __init__(self, sql):
        self.sql = sql  # the raw SQL text, emitted verbatim

    def __str__(self):
        return self.sql

    def __pg_repr__(self):
        # NOTE(review): hook presumably used by the driver to embed the
        # SQL verbatim instead of quoting it — confirm against caller
        return self.sql
1842
1843
# If run as script, print some information:

if __name__ == '__main__':
    print('PyGreSQL version', version)  # version is defined earlier in this module
    print('')
    print(__doc__)  # the module docstring with basic usage information
Note: See TracBrowser for help on using the repository browser.