source: trunk/pgdb.py @ 894

Last change on this file since 894 was 894, checked in by cito, 3 years ago

Cache the namedtuple classes used for query result rows

  • Property svn:keywords set to Id
File size: 59.7 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 894 2016-09-23 14:04:06Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17https://peps.python.org/pep-0249/
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary) if they are
32    # passed. The binding syntax is the same as the % operator
33    # for dictionaries, and no quoting is done.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
from __future__ import print_function, division

from _pg import *

__version__ = version

from datetime import date, time, datetime, timedelta, tzinfo
from time import localtime
from decimal import Decimal
from uuid import UUID as Uuid
from math import isnan, isinf
from collections import namedtuple
from functools import partial
from re import compile as regex
from json import loads as jsondecode, dumps as jsonencode

try:
    # Iterable moved to collections.abc in Python 3.3 and the alias in
    # collections was removed in Python 3.10, so try the new location first.
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)
96
try:
    from functools import lru_cache
except ImportError:  # Python < 3.2
    # Backport of a simplified lru_cache, closely following the CPython
    # implementation with a circular doubly linked list of cache entries.
    from functools import update_wrapper
    try:
        from _thread import RLock
    except ImportError:
        class RLock:  # for builds without threads
            def __enter__(self): pass

            def __exit__(self, exctype, excinst, exctb): pass

    def lru_cache(maxsize=128):
        """Simplified functools.lru_cache decorator for one argument."""

        def decorator(function):
            sentinel = object()  # unique object signaling a cache miss
            cache = {}
            get = cache.get
            lock = RLock()  # linked list updates are not thread-safe
            # each link is a list [prev, next, key, result];
            # root is the sentinel link of the circular list
            root = []
            # mutable holder [root, full] so the wrapper can rebind both
            root_full = [root, False]
            root[:] = [root, root, None, None]

            if maxsize == 0:

                def wrapper(arg):
                    # no caching -- just call the function
                    res = function(arg)
                    return res

            elif maxsize is None:

                def wrapper(arg):
                    # simple unbounded caching, no eviction
                    res = get(arg, sentinel)
                    if res is not sentinel:
                        return res
                    res = function(arg)
                    cache[arg] = res
                    return res

            else:

                def wrapper(arg):
                    # bounded LRU caching with eviction of the oldest entry
                    with lock:
                        link = get(arg)
                        if link is not None:
                            # cache hit: unlink the entry and move it to
                            # the most recently used position
                            root = root_full[0]
                            prev, next, _arg, res = link
                            prev[1] = next
                            next[0] = prev
                            last = root[0]
                            last[1] = root[0] = link
                            link[0] = last
                            link[1] = root
                            return res
                    res = function(arg)
                    with lock:
                        root, full = root_full
                        if arg in cache:
                            # another thread cached this entry while we
                            # were computing the result; keep the old link
                            pass
                        elif full:
                            # reuse the root as the new link and make the
                            # oldest link the new (empty) root
                            oldroot = root
                            oldroot[2] = arg
                            oldroot[3] = res
                            root = root_full[0] = oldroot[1]
                            oldarg = root[2]
                            oldres = root[3]  # keep reference
                            root[2] = root[3] = None
                            del cache[oldarg]
                            cache[arg] = oldroot
                        else:
                            # insert the new link before the root
                            last = root[0]
                            link = [last, root, arg, res]
                            last[1] = root[0] = cache[arg] = link
                            if len(cache) >= maxsize:
                                root_full[1] = True
                    return res

            wrapper.__wrapped__ = function
            return update_wrapper(wrapper, function)

        return decorator
179
180
### Module Constants

# compliant with DB API 2.0
apilevel = '2.0'

# the module may be shared, but not connections
threadsafety = 1

# this module uses extended Python format codes for parameters
paramstyle = 'pyformat'

# shortcut methods have been excluded from DB API 2 and
# are not recommended by the DB SIG, but they can be handy
shortcutmethods = 1
196
### Internal Type Handling

try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        return getargspec(func).args
else:

    def get_args(func):
        """Return the list of argument names of the given function."""
        return list(signature(func).parameters)

try:
    from datetime import timezone
except ImportError:  # Python < 3.2

    class timezone(tzinfo):
        """Simple timezone implementation."""

        def __init__(self, offset, name=None):
            """Initialize with a fixed utcoffset and an optional name.

            If no name is given, one like 'UTC+HH:MM' is derived
            from the offset.
            """
            self.offset = offset
            if not name:
                minutes = self.offset.days * 1440 + self.offset.seconds // 60
                if minutes < 0:
                    hours, minutes = divmod(-minutes, 60)
                    hours = -hours
                else:
                    hours, minutes = divmod(minutes, 60)
                name = 'UTC%+03d:%02d' % (hours, minutes)
            self.name = name

        def utcoffset(self, dt):
            """Return the offset of local time from UTC as a timedelta."""
            return self.offset

        def tzname(self, dt):
            """Return the name of the timezone."""
            return self.name

        def dst(self, dt):
            """Return the DST adjustment (fixed offsets have none)."""
            return None

    timezone.utc = timezone(timedelta(0), 'UTC')

    _has_timezone = False
else:
    _has_timezone = True
244
245# time zones used in Postgres timestamptz output
246_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
247    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
248    UCT='+0000', UTC='+0000', WET='+0000')
249
250
251def _timezone_as_offset(tz):
252    if tz.startswith(('+', '-')):
253        if len(tz) < 5:
254            return tz + '00'
255        return tz.replace(':', '')
256    return _timezones.get(tz, '+0000')
257
258
def _get_timezone(tz):
    """Return a timezone object for the given timezone string."""
    tz = _timezone_as_offset(tz)
    sign = -1 if tz[0] == '-' else 1
    minutes = sign * (60 * int(tz[1:3]) + int(tz[3:5]))
    return timezone(timedelta(minutes=minutes), tz)
265
266
def decimal_type(decimal_type=None):
    """Get or set the global type to be used for decimal values.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    global Decimal
    if decimal_type is None:
        # query mode: just report the currently used type
        return Decimal
    Decimal = decimal_type
    set_typecast('numeric', decimal_type)
    return Decimal
278
279
def cast_bool(value):
    """Cast a boolean value in database format to bool."""
    if not value:
        # empty values are treated as NULL
        return None
    return value[0] in ('t', 'T')
284
285
def cast_money(value):
    """Cast a money value in database format to Decimal."""
    if not value:
        # empty values are treated as NULL
        return None
    # a leading parenthesis denotes a negative amount
    value = value.replace('(', '-')
    digits = [c for c in value if c.isdigit() or c in '.-']
    return Decimal(''.join(digits))
291
292
def cast_int2vector(value):
    """Cast an int2vector value to a list of ints."""
    return list(map(int, value.split()))
296
297
def cast_date(value, connection):
    """Cast a date value."""
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value == '-infinity':
        return date.min
    if value == 'infinity':
        return date.max
    parts = value.split()
    if parts[-1] == 'BC':
        # dates before the common era are mapped to the minimum
        return date.min
    first = parts[0]
    if len(first) > 10:
        # years with more than four digits exceed the date range
        return date.max
    return datetime.strptime(first, connection.date_format()).date()
316
317
def cast_time(value):
    """Cast a time value."""
    if len(value) > 8:  # fractional seconds present
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    return datetime.strptime(value, fmt).time()
322
323
# pattern splitting a time value into the time part and the timezone part
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value."""
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    if len(value) > 8:  # fractional seconds present
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    if _has_timezone:
        # let strptime parse the timezone offset directly
        return datetime.strptime(
            value + _timezone_as_offset(tz), fmt + '%z').timetz()
    # Python < 3.2: parse naively, then attach the timezone
    return datetime.strptime(value, fmt).timetz().replace(
        tzinfo=_get_timezone(tz))
341
342
def cast_timestamp(value, connection):
    """Cast a timestamp value."""
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        # timestamps before the common era are mapped to the minimum
        return datetime.min
    fmt = connection.date_format()
    if fmt.endswith('-%Y') and len(parts) > 2:
        # the year comes last here (e.g. the 'postgres' DateStyle)
        parts = parts[1:5]
        if len(parts[3]) > 4:
            # years with more than four digits exceed the datetime range
            return datetime.max
        formats = ['%d %b' if fmt.startswith('%d') else '%b %d']
        formats.append('%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S')
        formats.append('%Y')
    else:
        # the date part comes first
        if len(parts[0]) > 10:
            return datetime.max
        formats = [fmt]
        formats.append('%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S')
    return datetime.strptime(' '.join(parts), ' '.join(formats))
364
365
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    Returns datetime.min/max for the infinities and for values
    outside of the range that datetime can represent.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    value = value.split()
    if value[-1] == 'BC':
        # timestamps before the common era are mapped to the minimum
        return datetime.min
    fmt = connection.date_format()
    if fmt.endswith('-%Y') and len(value) > 2:
        # the year comes last here (e.g. the 'postgres' DateStyle),
        # so drop the leading day name and keep the rest
        value = value[1:]
        if len(value[3]) > 4:
            # years with more than four digits exceed the datetime range
            return datetime.max
        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
        # the timezone is the final word
        value, tz = value[:-1], value[-1]
    else:
        if fmt.startswith('%Y-'):
            # ISO style: the timezone is appended to the time part
            tz = _re_timezone.match(value[1])
            if tz:
                value[1], tz = tz.groups()
            else:
                tz = '+0000'
        else:
            # otherwise the timezone is a separate word at the end
            value, tz = value[:-1], value[-1]
        if len(value[0]) > 10:
            # years with more than four digits exceed the datetime range
            return datetime.max
        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
    if _has_timezone:
        # let strptime parse the timezone offset directly
        value.append(_timezone_as_offset(tz))
        fmt.append('%z')
        return datetime.strptime(' '.join(value), ' '.join(fmt))
    # Python < 3.2: parse naively, then attach the timezone
    return datetime.strptime(' '.join(value), ' '.join(fmt)).replace(
        tzinfo=_get_timezone(tz))
401
402
# Patterns matching the four interval output styles of PostgreSQL
# (sql_standard, postgres, postgres_verbose and iso_8601).
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _frac_to_usecs(frac):
    """Convert fractional-second digits to microseconds.

    PostgreSQL trims trailing zeros in the fractional part, so the digits
    must be right-padded to six places before they are interpreted
    (e.g. '5' means half a second, i.e. 500000 microseconds).  Extra
    digits beyond microsecond precision are truncated.
    """
    return int(frac.ljust(6, '0')[:6])


def cast_interval(value):
    """Cast an interval value to a timedelta.

    Raises ValueError if the value matches none of the known styles.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        secs_ago = m.pop(5) == '-'
        usecs = _frac_to_usecs(m.pop(6))
        years, mons, days, hours, mins, secs = [int(d) for d in m]
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            secs_ago = m.pop(5) == '-'
            usecs = _frac_to_usecs(m.pop(6))
            m = [int(d) for d in m]
            if ago:
                # a trailing 'ago' inverts the sign of all fields
                m = [-d for d in m]
                usecs = -usecs
            years, mons, days, hours, mins, secs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                hours_ago = m.pop(3) == '-'
                usecs = _frac_to_usecs(m.pop(6))
                years, mons, days, hours, mins, secs = [int(d) for d in m]
                if hours_ago:
                    # a sign before the time part applies to the whole time
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = [d or '0' for d in m.groups()]
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    usecs = _frac_to_usecs(m.pop(6))
                    years, mons, days, hours, mins, secs = [int(d) for d in m]
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    # timedelta cannot represent months, so years and months are
    # approximated with fixed numbers of days
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
488
489
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
        'float4': float, 'float8': float,
        'numeric': Decimal, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': Uuid,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in local connection specific instances

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # create array cast (array types are prefixed with '_')
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                # store only if base type exists
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument."""
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            # e.g. built-in functions do not expose their signature
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type."""
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            # remove the cast functions, including the array cast
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                # invalidate a cached array cast so it gets rebuilt
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.
        """
        defaults = self.defaults
        if typ is None:
            self.clear()
            self.update(defaults)
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                cast = defaults.get(t)
                if cast:
                    self[t] = self._add_connection(cast)
                    # also reset the corresponding array cast
                    t = '_%s' % t
                    cast = defaults.get(t)
                    if cast:
                        self[t] = self._add_connection(cast)
                    else:
                        self.pop(t, None)
                else:
                    # no default exists, so simply remove the casts
                    self.pop(t, None)
                    self.pop('_%s' % t, None)

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        cast_array = self['anyarray']
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        cast_record = self['record']
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
614
_typecasts = Typecasts()  # the module-wide typecast registry


def get_typecast(typ):
    """Return the global typecast function for the given database type(s)."""
    return _typecasts.get(typ)


def set_typecast(typ, cast):
    """Register a global typecast function for the given database type(s).

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    _typecasts.set(typ, cast)


def reset_typecast(typ=None):
    """Restore the default global typecast(s) for the given type(s).

    When no type is specified, all typecasts will be reset.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    _typecasts.reset(typ)
641
642
class LocalTypecasts(Typecasts):
    """Map typecasts, including local composite types, to cast functions."""

    # fall back to the global typecast dictionary for the defaults
    defaults = _typecasts

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached."""
        if typ.startswith('_'):
            # create an array cast for the base type
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                # store only if the base type exists
                self[typ] = cast
        else:
            cast = self.defaults.get(typ)
            if cast:
                # store the default cast for faster access
                cast = self._add_connection(cast)
                self[typ] = cast
            else:
                # maybe it is a composite type defined in the database
                fields = self.get_fields(typ)
                if fields:
                    casts = [self[field.type] for field in fields]
                    fields = [field.name for field in fields]
                    cast = self.create_record_cast(typ, fields, casts)
                    self[typ] = cast
        return cast

    def get_fields(self, typ):
        """Return the fields for the given record type.

        This method will be replaced with a method that looks up the fields
        using the type cache of the connection.
        """
        return []
678
679
class TypeCode(str):
    """Class representing the type_code used by the DB-API 2.0.

    TypeCode objects are strings equal to the PostgreSQL type name,
    but carry some additional information.
    """

    @classmethod
    def create(cls, oid, name, len, type, category, delim, relid):
        """Create a type code for a PostgreSQL data type."""
        code = cls(name)
        # attach the additional type information as attributes
        code.oid, code.len, code.type = oid, len, type
        code.category, code.delim, code.relid = category, delim, relid
        return code
698
FieldInfo = namedtuple('FieldInfo', ['name', 'type'])  # a composite field


class TypeCache(dict):
    """Cache for database types.

    This cache maps type OIDs and names to TypeCode strings containing
    important information on the associated database type.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()
        self._typecasts = LocalTypecasts()
        # let the local typecasts look up composite fields via this cache
        self._typecasts.get_fields = self.get_fields
        self._typecasts.connection = cnx
        if cnx.server_version < 80400:
            # older remote databases (not officially supported)
            self._query_pg_type = ("SELECT oid, typname,"
                " typlen, typtype, null as typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s")
        else:
            self._query_pg_type = ("SELECT oid, typname,"
                " typlen, typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s")

    def __missing__(self, key):
        """Get the type info from the database if it is not cached."""
        if isinstance(key, int):
            oid = key
        else:
            # look the type name up via regtype; quote and escape it first
            if '.' not in key and '"' not in key:
                key = '"%s"' % (key,)
            oid = "'%s'::regtype" % (self._escape_string(key),)
        try:
            self._src.execute(self._query_pg_type % (oid,))
        except ProgrammingError:
            res = None
        else:
            res = self._src.fetch(1)
        if not res:
            raise KeyError('Type %s could not be found' % (key,))
        res = res[0]
        type_code = TypeCode.create(int(res[0]), res[1],
            int(res[2]), res[3], res[4], res[5], int(res[6]))
        # cache the result under both the OID and the type name
        self[type_code.oid] = self[str(type_code)] = type_code
        return type_code

    def get(self, key, default=None):
        """Get the type even if it is not cached."""
        try:
            return self[key]
        except KeyError:
            return default

    def get_fields(self, typ):
        """Get the names and types of the fields of composite types."""
        if not isinstance(typ, TypeCode):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None  # this type is not composite
        self._src.execute("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % (typ.relid,))
        return [FieldInfo(name, self.get(int(oid)))
            for name, oid in self._src.fetch(-1)]

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type."""
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        cast = self.get_typecast(typ)
        if not cast or cast is str:
            # no typecast is necessary
            return value
        return cast(value)
792
793
794class _quotedict(dict):
795    """Dictionary with auto quoting of its items.
796
797    The quote attribute must be set to the desired quote function.
798    """
799
800    def __getitem__(self, key):
801        return self.quote(super(_quotedict, self).__getitem__(key))
802
803
### Error Messages

def _db_error(msg, cls=DatabaseError):
    """Return a DatabaseError instance with an empty sqlstate attribute."""
    exc = cls(msg)
    exc.sqlstate = None
    return exc


def _op_error(msg):
    """Return an OperationalError with an empty sqlstate attribute."""
    return _db_error(msg, OperationalError)
816
817
818### Row Tuples
819
820# The result rows for database operations are returned as named tuples
821# by default. Since creating namedtuple classes is a somewhat expensive
822# operation, we cache up to 1024 of these classes by default.
823
824@lru_cache(maxsize=1024)
825def _row_factory(names):
826    """Get a namedtuple factory for row results with the given names."""
827    try:
828        try:
829            return namedtuple('Row', names, rename=True)._make
830        except TypeError:  # Python 2.6 and 3.0 do not support rename
831            names = [v if v.isalnum() else 'column_%d' % (n,)
832                     for n, v in enumerate(names)]
833            return namedtuple('Row', names)._make
834    except ValueError:  # there is still a problem with the field names
835        names = ['column_%d' % (n,) for n in range(len(names))]
836        return namedtuple('Row', names)._make
837
838
def set_row_factory_size(maxsize):
    """Change the size of the namedtuple factory cache.

    If maxsize is set to None, the cache can grow without bound.
    """
    global _row_factory
    # rebuild the cache decorator around the undecorated function
    _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__)
846
847
848### Cursor Object
849
850class Cursor(object):
851    """Cursor object."""
852
    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx  # the wrapped classic connection
        self.type_cache = dbcnx.type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns
        self._description = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            # (row_factory/build_row_factory are presumably defined
            # later on the class -- not visible in this chunk)
            self.row_factory = None
        else:
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None
869
    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        # the iteration itself is presumably driven by a next/__next__
        # method defined elsewhere on the class (not visible in this chunk)
        return self
873
    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        # the cursor itself acts as the context manager
        return self
877
    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object.

        The cursor is closed regardless of whether an exception occurred;
        exceptions are not suppressed (implicitly returns None).
        """
        self.close()
881
882    def _quote(self, value):
883        """Quote value depending on its type."""
884        if value is None:
885            return 'NULL'
886        if isinstance(value, (Hstore, Json)):
887            value = str(value)
888        if isinstance(value, basestring):
889            if isinstance(value, Binary):
890                value = self._cnx.escape_bytea(value)
891                if bytes is not str:  # Python >= 3.0
892                    value = value.decode('ascii')
893            else:
894                value = self._cnx.escape_string(value)
895            return "'%s'" % (value,)
896        if isinstance(value, float):
897            if isinf(value):
898                return "'-Infinity'" if value < 0 else "'Infinity'"
899            if isnan(value):
900                return "'NaN'"
901            return value
902        if isinstance(value, (int, long, Decimal, Literal)):
903            return value
904        if isinstance(value, datetime):
905            if value.tzinfo:
906                return "'%s'::timestamptz" % (value,)
907            return "'%s'::timestamp" % (value,)
908        if isinstance(value, date):
909            return "'%s'::date" % (value,)
910        if isinstance(value, time):
911            if value.tzinfo:
912                return "'%s'::timetz" % (value,)
913            return "'%s'::time" % value
914        if isinstance(value, timedelta):
915            return "'%s'::interval" % (value,)
916        if isinstance(value, Uuid):
917            return "'%s'::uuid" % (value,)
918        if isinstance(value, list):
919            # Quote value as an ARRAY constructor. This is better than using
920            # an array literal because it carries the information that this is
921            # an array and not a string.  One issue with this syntax is that
922            # you need to add an explicit typecast when passing empty arrays.
923            # The ARRAY keyword is actually only necessary at the top level.
924            if not value:  # exception for empty array
925                return "'{}'"
926            q = self._quote
927            try:
928                return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),)
929            except UnicodeEncodeError:  # Python 2 with non-ascii values
930                return u'ARRAY[%s]' % (','.join(unicode(q(v)) for v in value),)
931        if isinstance(value, tuple):
932            # Quote as a ROW constructor.  This is better than using a record
933            # literal because it carries the information that this is a record
934            # and not a string.  We don't use the keyword ROW in order to make
935            # this usable with the IN syntax as well.  It is only necessary
936            # when the records has a single column which is not really useful.
937            q = self._quote
938            try:
939                return '(%s)' % (','.join(str(q(v)) for v in value),)
940            except UnicodeEncodeError:  # Python 2 with non-ascii values
941                return u'(%s)' % (','.join(unicode(q(v)) for v in value),)
942        try:
943            value = value.__pg_repr__()
944        except AttributeError:
945            raise InterfaceError(
946                'Do not know how to adapt type %s' % (type(value),))
947        if isinstance(value, (tuple, list)):
948            value = self._quote(value)
949        return value
950
951    def _quoteparams(self, string, parameters):
952        """Quote parameters.
953
954        This function works for both mappings and sequences.
955
956        The function should be used even when there are no parameters,
957        so that we have a consistent behavior regarding percent signs.
958        """
959        if not parameters:
960            try:
961                return string % ()  # unescape literal quotes if possible
962            except (TypeError, ValueError):
963                return string  # silently accept unescaped quotes
964        if isinstance(parameters, dict):
965            parameters = _quotedict(parameters)
966            parameters.quote = self._quote
967        else:
968            parameters = tuple(map(self._quote, parameters))
969        return string % parameters
970
971    def _make_description(self, info):
972        """Make the description tuple for the given field info."""
973        name, typ, size, mod = info[1:]
974        type_code = self.type_cache[typ]
975        if mod > 0:
976            mod -= 4
977        if type_code == 'numeric':
978            precision, scale = mod >> 16, mod & 0xffff
979            size = precision
980        else:
981            if not size:
982                size = type_code.size
983            if size == -1:
984                size = mod
985            precision = scale = None
986        return CursorDescription(name, type_code,
987            None, size, precision, scale, None)
988
989    @property
990    def description(self):
991        """Read-only attribute describing the result columns."""
992        descr = self._description
993        if self._description is True:
994            make = self._make_description
995            descr = [make(info) for info in self._src.listinfo()]
996            self._description = descr
997        return descr
998
999    @property
1000    def colnames(self):
1001        """Unofficial convenience method for getting the column names."""
1002        return [d[0] for d in self.description]
1003
1004    @property
1005    def coltypes(self):
1006        """Unofficial convenience method for getting the column types."""
1007        return [d[1] for d in self.description]
1008
1009    def close(self):
1010        """Close the cursor object."""
1011        self._src.close()
1012
1013    def execute(self, operation, parameters=None):
1014        """Prepare and execute a database operation (query or command)."""
1015        # The parameters may also be specified as list of tuples to e.g.
1016        # insert multiple rows in a single operation, but this kind of
1017        # usage is deprecated.  We make several plausibility checks because
1018        # tuples can also be passed with the meaning of ROW constructors.
1019        if (parameters and isinstance(parameters, list)
1020                and len(parameters) > 1
1021                and all(isinstance(p, tuple) for p in parameters)
1022                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
1023            return self.executemany(operation, parameters)
1024        else:
1025            # not a list of tuples
1026            return self.executemany(operation, [parameters])
1027
1028    def executemany(self, operation, seq_of_parameters):
1029        """Prepare operation and execute it against a parameter sequence."""
1030        if not seq_of_parameters:
1031            # don't do anything without parameters
1032            return
1033        self._description = None
1034        self.rowcount = -1
1035        # first try to execute all queries
1036        rowcount = 0
1037        sql = "BEGIN"
1038        try:
1039            if not self._dbcnx._tnx:
1040                try:
1041                    self._src.execute(sql)
1042                except DatabaseError:
1043                    raise  # database provides error message
1044                except Exception:
1045                    raise _op_error("Can't start transaction")
1046                self._dbcnx._tnx = True
1047            for parameters in seq_of_parameters:
1048                sql = operation
1049                sql = self._quoteparams(sql, parameters)
1050                rows = self._src.execute(sql)
1051                if rows:  # true if not DML
1052                    rowcount += rows
1053                else:
1054                    self.rowcount = -1
1055        except DatabaseError:
1056            raise  # database provides error message
1057        except Error as err:
1058            raise _db_error(
1059                "Error in '%s': '%s' " % (sql, err), InterfaceError)
1060        except Exception as err:
1061            raise _op_error("Internal error in '%s': %s" % (sql, err))
1062        # then initialize result raw count and description
1063        if self._src.resulttype == RESULT_DQL:
1064            self._description = True  # fetch on demand
1065            self.rowcount = self._src.ntuples
1066            self.lastrowid = None
1067            if self.build_row_factory:
1068                self.row_factory = self.build_row_factory()
1069        else:
1070            self.rowcount = rowcount
1071            self.lastrowid = self._src.oidstatus()
1072        # return the cursor object, so you can write statements such as
1073        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
1074        return self
1075
1076    def fetchone(self):
1077        """Fetch the next row of a query result set."""
1078        res = self.fetchmany(1, False)
1079        try:
1080            return res[0]
1081        except IndexError:
1082            return None
1083
1084    def fetchall(self):
1085        """Fetch all (remaining) rows of a query result."""
1086        return self.fetchmany(-1, False)
1087
1088    def fetchmany(self, size=None, keep=False):
1089        """Fetch the next set of rows of a query result.
1090
1091        The number of rows to fetch per call is specified by the
1092        size parameter. If it is not given, the cursor's arraysize
1093        determines the number of rows to be fetched. If you set
1094        the keep parameter to true, this is kept as new arraysize.
1095        """
1096        if size is None:
1097            size = self.arraysize
1098        if keep:
1099            self.arraysize = size
1100        try:
1101            result = self._src.fetch(size)
1102        except DatabaseError:
1103            raise
1104        except Error as err:
1105            raise _db_error(str(err))
1106        typecast = self.type_cache.typecast
1107        return [self.row_factory([typecast(value, typ)
1108            for typ, value in zip(self.coltypes, row)]) for row in result]
1109
1110    def callproc(self, procname, parameters=None):
1111        """Call a stored database procedure with the given name.
1112
1113        The sequence of parameters must contain one entry for each input
1114        argument that the procedure expects. The result of the call is the
1115        same as this input sequence; replacement of output and input/output
1116        parameters in the return value is currently not supported.
1117
1118        The procedure may also provide a result set as output. These can be
1119        requested through the standard fetch methods of the cursor.
1120        """
1121        n = parameters and len(parameters) or 0
1122        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
1123        self.execute(query, parameters)
1124        return parameters
1125
1126    def copy_from(self, stream, table,
1127            format=None, sep=None, null=None, size=None, columns=None):
1128        """Copy data from an input stream to the specified table.
1129
1130        The input stream can be a file-like object with a read() method or
1131        it can also be an iterable returning a row or multiple rows of input
1132        on each iteration.
1133
1134        The format must be text, csv or binary. The sep option sets the
1135        column separator (delimiter) used in the non binary formats.
1136        The null option sets the textual representation of NULL in the input.
1137
1138        The size option sets the size of the buffer used when reading data
1139        from file-like objects.
1140
1141        The copy operation can be restricted to a subset of columns. If no
1142        columns are specified, all of them will be copied.
1143        """
1144        binary_format = format == 'binary'
1145        try:
1146            read = stream.read
1147        except AttributeError:
1148            if size:
1149                raise ValueError("Size must only be set for file-like objects")
1150            if binary_format:
1151                input_type = bytes
1152                type_name = 'byte strings'
1153            else:
1154                input_type = basestring
1155                type_name = 'strings'
1156
1157            if isinstance(stream, basestring):
1158                if not isinstance(stream, input_type):
1159                    raise ValueError("The input must be %s" % (type_name,))
1160                if not binary_format:
1161                    if isinstance(stream, str):
1162                        if not stream.endswith('\n'):
1163                            stream += '\n'
1164                    else:
1165                        if not stream.endswith(b'\n'):
1166                            stream += b'\n'
1167
1168                def chunks():
1169                    yield stream
1170
1171            elif isinstance(stream, Iterable):
1172
1173                def chunks():
1174                    for chunk in stream:
1175                        if not isinstance(chunk, input_type):
1176                            raise ValueError(
1177                                "Input stream must consist of %s"
1178                                % (type_name,))
1179                        if isinstance(chunk, str):
1180                            if not chunk.endswith('\n'):
1181                                chunk += '\n'
1182                        else:
1183                            if not chunk.endswith(b'\n'):
1184                                chunk += b'\n'
1185                        yield chunk
1186
1187            else:
1188                raise TypeError("Need an input stream to copy from")
1189        else:
1190            if size is None:
1191                size = 8192
1192            elif not isinstance(size, int):
1193                raise TypeError("The size option must be an integer")
1194            if size > 0:
1195
1196                def chunks():
1197                    while True:
1198                        buffer = read(size)
1199                        yield buffer
1200                        if not buffer or len(buffer) < size:
1201                            break
1202
1203            else:
1204
1205                def chunks():
1206                    yield read()
1207
1208        if not table or not isinstance(table, basestring):
1209            raise TypeError("Need a table to copy to")
1210        if table.lower().startswith('select'):
1211                raise ValueError("Must specify a table, not a query")
1212        else:
1213            table = '"%s"' % (table,)
1214        operation = ['copy %s' % (table,)]
1215        options = []
1216        params = []
1217        if format is not None:
1218            if not isinstance(format, basestring):
1219                raise TypeError("The frmat option must be be a string")
1220            if format not in ('text', 'csv', 'binary'):
1221                raise ValueError("Invalid format")
1222            options.append('format %s' % (format,))
1223        if sep is not None:
1224            if not isinstance(sep, basestring):
1225                raise TypeError("The sep option must be a string")
1226            if format == 'binary':
1227                raise ValueError(
1228                    "The sep option is not allowed with binary format")
1229            if len(sep) != 1:
1230                raise ValueError(
1231                    "The sep option must be a single one-byte character")
1232            options.append('delimiter %s')
1233            params.append(sep)
1234        if null is not None:
1235            if not isinstance(null, basestring):
1236                raise TypeError("The null option must be a string")
1237            options.append('null %s')
1238            params.append(null)
1239        if columns:
1240            if not isinstance(columns, basestring):
1241                columns = ','.join('"%s"' % (col,) for col in columns)
1242            operation.append('(%s)' % (columns,))
1243        operation.append("from stdin")
1244        if options:
1245            operation.append('(%s)' % (','.join(options),))
1246        operation = ' '.join(operation)
1247
1248        putdata = self._src.putdata
1249        self.execute(operation, params)
1250
1251        try:
1252            for chunk in chunks():
1253                putdata(chunk)
1254        except BaseException as error:
1255            self.rowcount = -1
1256            # the following call will re-raise the error
1257            putdata(error)
1258        else:
1259            self.rowcount = putdata(None)
1260
1261        # return the cursor object, so you can chain operations
1262        return self
1263
1264    def copy_to(self, stream, table,
1265            format=None, sep=None, null=None, decode=None, columns=None):
1266        """Copy data from the specified table to an output stream.
1267
1268        The output stream can be a file-like object with a write() method or
1269        it can also be None, in which case the method will return a generator
1270        yielding a row on each iteration.
1271
1272        Output will be returned as byte strings unless you set decode to true.
1273
1274        Note that you can also use a select query instead of the table name.
1275
1276        The format must be text, csv or binary. The sep option sets the
1277        column separator (delimiter) used in the non binary formats.
1278        The null option sets the textual representation of NULL in the output.
1279
1280        The copy operation can be restricted to a subset of columns. If no
1281        columns are specified, all of them will be copied.
1282        """
1283        binary_format = format == 'binary'
1284        if stream is not None:
1285            try:
1286                write = stream.write
1287            except AttributeError:
1288                raise TypeError("Need an output stream to copy to")
1289        if not table or not isinstance(table, basestring):
1290            raise TypeError("Need a table to copy to")
1291        if table.lower().startswith('select'):
1292            if columns:
1293                raise ValueError("Columns must be specified in the query")
1294            table = '(%s)' % (table,)
1295        else:
1296            table = '"%s"' % (table,)
1297        operation = ['copy %s' % (table,)]
1298        options = []
1299        params = []
1300        if format is not None:
1301            if not isinstance(format, basestring):
1302                raise TypeError("The format option must be a string")
1303            if format not in ('text', 'csv', 'binary'):
1304                raise ValueError("Invalid format")
1305            options.append('format %s' % (format,))
1306        if sep is not None:
1307            if not isinstance(sep, basestring):
1308                raise TypeError("The sep option must be a string")
1309            if binary_format:
1310                raise ValueError(
1311                    "The sep option is not allowed with binary format")
1312            if len(sep) != 1:
1313                raise ValueError(
1314                    "The sep option must be a single one-byte character")
1315            options.append('delimiter %s')
1316            params.append(sep)
1317        if null is not None:
1318            if not isinstance(null, basestring):
1319                raise TypeError("The null option must be a string")
1320            options.append('null %s')
1321            params.append(null)
1322        if decode is None:
1323            if format == 'binary':
1324                decode = False
1325            else:
1326                decode = str is unicode
1327        else:
1328            if not isinstance(decode, (int, bool)):
1329                raise TypeError("The decode option must be a boolean")
1330            if decode and binary_format:
1331                raise ValueError(
1332                    "The decode option is not allowed with binary format")
1333        if columns:
1334            if not isinstance(columns, basestring):
1335                columns = ','.join('"%s"' % (col,) for col in columns)
1336            operation.append('(%s)' % (columns,))
1337
1338        operation.append("to stdout")
1339        if options:
1340            operation.append('(%s)' % (','.join(options),))
1341        operation = ' '.join(operation)
1342
1343        getdata = self._src.getdata
1344        self.execute(operation, params)
1345
1346        def copy():
1347            self.rowcount = 0
1348            while True:
1349                row = getdata(decode)
1350                if isinstance(row, int):
1351                    if self.rowcount != row:
1352                        self.rowcount = row
1353                    break
1354                self.rowcount += 1
1355                yield row
1356
1357        if stream is None:
1358            # no input stream, return the generator
1359            return copy()
1360
1361        # write the rows to the file-like input stream
1362        for row in copy():
1363            write(row)
1364
1365        # return the cursor object, so you can chain operations
1366        return self
1367
1368    def __next__(self):
1369        """Return the next row (support for the iteration protocol)."""
1370        res = self.fetchone()
1371        if res is None:
1372            raise StopIteration
1373        return res
1374
1375    # Note that since Python 2.6 the iterator protocol uses __next()__
1376    # instead of next(), we keep it only for backward compatibility of pgdb.
1377    next = __next__
1378
1379    @staticmethod
1380    def nextset():
1381        """Not supported."""
1382        raise NotSupportedError("The nextset() method is not supported")
1383
1384    @staticmethod
1385    def setinputsizes(sizes):
1386        """Not supported."""
1387        pass  # unsupported, but silently passed
1388
1389    @staticmethod
1390    def setoutputsize(size, column=0):
1391        """Not supported."""
1392        pass  # unsupported, but silently passed
1393
1394    @staticmethod
1395    def row_factory(row):
1396        """Process rows before they are returned.
1397
1398        You can overwrite this statically with a custom row factory, or
1399        you can build a row factory dynamically with build_row_factory().
1400
1401        For example, you can create a Cursor class that returns rows as
1402        Python dictionaries like this:
1403
1404            class DictCursor(pgdb.Cursor):
1405
1406                def row_factory(self, row):
1407                    return {desc[0]: value
1408                        for desc, value in zip(self.description, row)}
1409
1410            cur = DictCursor(con)  # get one DictCursor instance or
1411            con.cursor_type = DictCursor  # always use DictCursor instances
1412        """
1413        raise NotImplementedError
1414
1415    def build_row_factory(self):
1416        """Build a row factory based on the current description.
1417
1418        This implementation builds a row factory for creating named tuples.
1419        You can overwrite this method if you want to dynamically create
1420        different row factories whenever the column description changes.
1421        """
1422        names = self.colnames
1423        if names:
1424            return _row_factory(tuple(names))
1425
1426
1427
# Named tuple describing one result column, per the DB-API 2 specification
# (display_size, precision, scale and null_ok are not implemented).
CursorDescription = namedtuple('CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
1431
1432
1433### Connection Objects
1434
class Connection(object):
    """Connection object."""

    # Expose the standard exception classes as connection attributes,
    # as suggested by the DB-API 2 optional extensions.
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object."""
        self._cnx = cnx  # the wrapped low-level connection
        self._tnx = False  # whether a transaction is currently open
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("Invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction.
        """
        if et is not None or ev is not None or tb is not None:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection object."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            # end the pending transaction, ignoring failures on the way out
            try:
                self.rollback()
            except DatabaseError:
                pass
        self._cnx.close()
        self._cnx = None

    @property
    def closed(self):
        """Check whether the connection has been closed or is broken."""
        try:
            if not self._cnx:
                return True
            return self._cnx.status != 1
        except TypeError:
            return True

    def commit(self):
        """Commit any pending transaction to the database."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return  # nothing to commit
        self._tnx = False
        try:
            self._cnx.source().execute("COMMIT")
        except DatabaseError:
            raise  # database provides error message
        except Exception:
            raise _op_error("Can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return  # nothing to roll back
        self._tnx = False
        try:
            self._cnx.source().execute("ROLLBACK")
        except DatabaseError:
            raise  # database provides error message
        except Exception:
            raise _op_error("Can't rollback")

    def cursor(self):
        """Return a new cursor object using the connection."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("Invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cur = self.cursor()
            cur.execute(operation, params)
            return cur

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cur = self.cursor()
            cur.executemany(operation, param_seq)
            return cur
1550
1551
1552### Module Interface
1553
_connect = connect

def connect(dsn=None,
        user=None, password=None,
        host=None, database=None, **kwargs):
    """Connect to a database.

    The DSN has the form 'host:database:user:password:opt'; all parts
    are optional and can be overridden by the corresponding keyword
    arguments.  A port can be given as part of the host parameter
    ('hostname:port').  Additional connection parameters can be passed
    as keyword arguments and will be forwarded in the connection string.
    """
    # first get params from DSN
    dbport = -1
    dbhost = ""
    dbname = ""
    dbuser = ""
    dbpasswd = ""
    dbopt = ""
    try:
        params = dsn.split(":")
        dbhost = params[0]
        dbname = params[1]
        dbuser = params[2]
        dbpasswd = params[3]
        dbopt = params[4]
    except (AttributeError, IndexError, TypeError):
        pass  # accept partial or missing DSNs

    # override if necessary
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbname = database
    if host is not None:
        try:
            params = host.split(":")
            dbhost = params[0]
            dbport = int(params[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # accept host without port

    # empty host is localhost
    if dbhost == "":
        dbhost = None
    if dbuser == "":
        dbuser = None

    # pass keyword arguments as connection info string
    if kwargs:
        kwargs = list(kwargs.items())
        if '=' in dbname:
            dbname = [dbname]
        else:
            kwargs.insert(0, ('dbname', dbname))
            dbname = []
        for kw, value in kwargs:
            value = str(value)
            # Per the connection string (conninfo) syntax, single quotes
            # and backslashes must always be escaped with a backslash;
            # backslashes must be doubled first, otherwise the backslashes
            # added for the quotes would be doubled again.
            value = value.replace('\\', '\\\\').replace("'", "\\'")
            if not value or ' ' in value:
                # empty values and values with spaces must be quoted
                value = "'%s'" % (value,)
            dbname.append('%s=%s' % (kw, value))
        dbname = ' '.join(dbname)

    # open the connection
    cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
1617
1618
1619### Types Handling
1620
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.
    """

    def __new__(cls, values):
        """Create a type object from a string or sequence of type names."""
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        """Compare equal to any of the contained type names."""
        if isinstance(other, basestring):
            if other.startswith('_'):  # treat array types like base types
                other = other[1:]
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):  # treat array types like base types
                other = other[1:]
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Defining __eq__ sets __hash__ to None under Python 3, which would
    # make Type objects unhashable; restore the inherited frozenset hash.
    __hash__ = frozenset.__hash__
1648
1649
class ArrayType:
    """Type class for PostgreSQL array types."""

    def __eq__(self, other):
        # PostgreSQL marks array type names with a leading underscore.
        if isinstance(other, basestring):
            return other.startswith('_')
        return isinstance(other, ArrayType)

    def __ne__(self, other):
        return not self.__eq__(other)
1664
1665
class RecordType:
    """Type class for PostgreSQL record types."""

    def __eq__(self, other):
        # Composite types have the category 'c' in the type cache.
        if isinstance(other, TypeCode):
            return other.type == 'c'
        if isinstance(other, basestring):
            return other == 'record'
        return isinstance(other, RecordType)

    def __ne__(self, other):
        return not self.__eq__(other)
1684
1685
# Mandatory type objects defined by DB-API 2 specs:
# (these compare equal to the corresponding PostgreSQL type names)

STRING = Type('char bpchar name text varchar')
BINARY = Type('bytea')
NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
DATETIME = Type('date time timetz timestamp timestamptz interval'
    ' abstime reltime')  # these are very old
ROWID = Type('oid')


# Additional type objects (more specific):

BOOL = Type('bool')
SMALLINT = Type('int2')
INTEGER = Type('int2 int4 int8 serial')
LONG = Type('int8')
FLOAT = Type('float4 float8')
NUMERIC = Type('numeric')
MONEY = Type('money')
DATE = Type('date')
TIME = Type('time timetz')
TIMESTAMP = Type('timestamp timestamptz')
INTERVAL = Type('interval')
UUID = Type('uuid')
HSTORE = Type('hstore')
JSON = Type('json jsonb')

# Type object for arrays (also equate to their base types):

ARRAY = ArrayType()

# Type object for records (encompassing all composite types):

RECORD = RecordType()
1720
1721
1722# Mandatory type helpers defined by DB-API 2 specs:
1723
def Date(year, month, day):
    """Construct an object holding a date value.

    Returns a datetime.date instance.
    """
    return date(year=year, month=month, day=day)
1727
1728
def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Construct an object holding a time value.

    Returns a datetime.time instance.
    """
    return time(hour=hour, minute=minute, second=second,
        microsecond=microsecond, tzinfo=tzinfo)
1732
1733
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Construct an object holding a time stamp value.

    Returns a datetime.datetime instance.
    """
    return datetime(year, month, day, hour, minute,
        second, microsecond, tzinfo=tzinfo)
1738
1739
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value."""
    # localtime() interprets the ticks (seconds since the epoch)
    # in the local time zone.
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
1743
1744
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value."""
    # Fields 3..5 of the localtime() struct are hour, minute and second.
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
1748
1749
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value."""
    when = localtime(ticks)
    return Timestamp(when[0], when[1], when[2], when[3], when[4], when[5])
1753
1754
# DB-API 2 requires a Binary() constructor; subclassing bytes lets
# the wrapped value be used wherever plain bytes are accepted.
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value."""
1757
1758
1759# Additional type helpers for PyGreSQL:
1760
def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
    """Construct an object holding a time interval value.

    Returns a datetime.timedelta instance.
    """
    # Fix: the docstring used to misspell "interval" as "inverval".
    # All components are passed by keyword for clarity; timedelta
    # normalizes them into days/seconds/microseconds itself.
    return timedelta(days=days, hours=hours, minutes=minutes,
        seconds=seconds, microseconds=microseconds)
1765
1766
1767Uuid = Uuid  # Construct an object holding a UUID value
1768
1769
class Hstore(dict):
    """Wrapper class for marking hstore values."""

    # Values matching this pattern must be quoted: the NULL keyword
    # (any letter case) or anything containing space, comma, '=' or '>'.
    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
    # Double quotes and backslashes must be escaped with a backslash.
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s):
        """Quote a single hstore key or value for output."""
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        # Decide on quoting before escaping, since escaping inserts
        # backslashes that would otherwise not trigger the quote rule.
        needs_quotes = cls._re_quote.search(s)
        s = cls._re_escape.sub(r'\\\1', s)
        return '"%s"' % (s,) if needs_quotes else s

    def __str__(self):
        """Render the dict in hstore key=>value syntax."""
        quote = self._quote
        pairs = ('%s=>%s' % (quote(key), quote(value))
            for key, value in self.items())
        return ','.join(pairs)
1791
1792
class Json:
    """Construct a wrapper for holding an object serializable to JSON."""

    def __init__(self, obj, encode=None):
        """Wrap obj, optionally with a custom encode function."""
        self.obj = obj
        # Fall back to the module's default JSON encoder.
        self.encode = encode or jsonencode

    def __str__(self):
        """Serialize the wrapped object to a JSON string."""
        value = self.obj
        # Strings are passed through untouched, assuming they already
        # contain serialized JSON.
        if isinstance(value, basestring):
            return value
        return self.encode(value)
1805
1806
class Literal:
    """Construct a wrapper for holding a literal SQL string."""

    def __init__(self, sql):
        """Store the raw SQL string."""
        self.sql = sql

    def __str__(self):
        """Return the SQL string unchanged."""
        return self.sql

    # Aliasing __str__ ensures the raw SQL is emitted verbatim when the
    # wrapper is rendered via the __pg_repr__ protocol.
    __pg_repr__ = __str__
1817
1818# If run as script, print some information:
1819
if __name__ == '__main__':
    # Print the PyGreSQL version (imported earlier in this module)
    # followed by the module docstring with basic usage information.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.