source: trunk/pgdb.py

Last change on this file was 901, checked in by cito, 22 months ago

Improve creation of named tuples in Python 2.6 and 3.0

1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 901 2017-01-06 12:25:02Z darcy $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host, database,
24    # user and password as keyword arguments. To pass a port,
25    # include it in the host parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary) if they are
32    # passed. The binding syntax is the same as the % operator
33    # for dictionaries, and no quoting is done.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size and null_ok are not implemented;
52    # precision and scale are only provided for numeric columns.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
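
A minimal end-to-end sketch (the database and table names here are only
examples, not part of the module):

    con = pgdb.connect(database='testdb')
    cur = con.cursor()
    cur.execute("insert into fruits values (%(name)s, %(qty)s)",
        {'name': 'apple', 'qty': 7})
    con.commit()
    cur.execute("select name, qty from fruits where qty > %(min)s",
        {'min': 5})
    rows = cur.fetchall()
    cur.close()
    con.close()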
64"""
65
66from __future__ import print_function, division
67
68from _pg import *
69
70__version__ = version
71
72from datetime import date, time, datetime, timedelta, tzinfo
73from time import localtime
74from decimal import Decimal
75from uuid import UUID as Uuid
76from math import isnan, isinf
77from collections import namedtuple, Iterable
78from keyword import iskeyword
79from functools import partial
80from re import compile as regex
81from json import loads as jsondecode, dumps as jsonencode
82
83try:
84    long
85except NameError:  # Python >= 3.0
86    long = int
87
88try:
89    unicode
90except NameError:  # Python >= 3.0
91    unicode = str
92
93try:
94    basestring
95except NameError:  # Python >= 3.0
96    basestring = (str, bytes)
97
98try:
99    from functools import lru_cache
100except ImportError:  # Python < 3.2
101    from functools import update_wrapper
102    try:
103        from _thread import RLock
104    except ImportError:
105        class RLock:  # for builds without threads
106            def __enter__(self): pass
107
108            def __exit__(self, exctype, excinst, exctb): pass
109
110    def lru_cache(maxsize=128):
111        """Simplified functools.lru_cache decorator for one argument."""
112
113        def decorator(function):
114            sentinel = object()
115            cache = {}
116            get = cache.get
117            lock = RLock()
118            root = []
119            root_full = [root, False]
120            root[:] = [root, root, None, None]
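            # Each cached entry is a link [prev, next, arg, result] in a
            # circular doubly-linked list around the root sentinel; when the
            # cache is full, the root takes the new entry and the oldest
            # entry's link becomes the new root.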
121
122            if maxsize == 0:
123
124                def wrapper(arg):
125                    res = function(arg)
126                    return res
127
128            elif maxsize is None:
129
130                def wrapper(arg):
131                    res = get(arg, sentinel)
132                    if res is not sentinel:
133                        return res
134                    res = function(arg)
135                    cache[arg] = res
136                    return res
137
138            else:
139
140                def wrapper(arg):
141                    with lock:
142                        link = get(arg)
143                        if link is not None:
144                            root = root_full[0]
145                            prev, next, _arg, res = link
146                            prev[1] = next
147                            next[0] = prev
148                            last = root[0]
149                            last[1] = root[0] = link
150                            link[0] = last
151                            link[1] = root
152                            return res
153                    res = function(arg)
154                    with lock:
155                        root, full = root_full
156                        if arg in cache:
157                            pass
158                        elif full:
159                            oldroot = root
160                            oldroot[2] = arg
161                            oldroot[3] = res
162                            root = root_full[0] = oldroot[1]
163                            oldarg = root[2]
164                            oldres = root[3]  # keep reference
165                            root[2] = root[3] = None
166                            del cache[oldarg]
167                            cache[arg] = oldroot
168                        else:
169                            last = root[0]
170                            link = [last, root, arg, res]
171                            last[1] = root[0] = cache[arg] = link
172                            if len(cache) >= maxsize:
173                                root_full[1] = True
174                    return res
175
176            wrapper.__wrapped__ = function
177            return update_wrapper(wrapper, function)
178
179        return decorator
180
181
182### Module Constants
183
184# compliant with DB API 2.0
185apilevel = '2.0'
186
187# module may be shared, but not connections
188threadsafety = 1
189
190# this module uses the extended Python format codes
191paramstyle = 'pyformat'
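# For example (an illustrative query; the table and column are made up):
#     cursor.execute("select * from weather where city = %(city)s",
#         {'city': 'Berlin'})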
192
193# shortcut methods have been excluded from DB API 2 and
194# are not recommended by the DB SIG, but they can be handy
195shortcutmethods = 1
196
197
198### Internal Type Handling
199
200try:
201    from inspect import signature
202except ImportError:  # Python < 3.3
203    from inspect import getargspec
204
205    def get_args(func):
206        return getargspec(func).args
207else:
208
209    def get_args(func):
210        return list(signature(func).parameters)
211
212try:
213    from datetime import timezone
214except ImportError:  # Python < 3.2
215
216    class timezone(tzinfo):
217        """Simple timezone implementation."""
218
219        def __init__(self, offset, name=None):
220            self.offset = offset
221            if not name:
222                minutes = self.offset.days * 1440 + self.offset.seconds // 60
223                if minutes < 0:
224                    hours, minutes = divmod(-minutes, 60)
225                    hours = -hours
226                else:
227                    hours, minutes = divmod(minutes, 60)
228                name = 'UTC%+03d:%02d' % (hours, minutes)
229            self.name = name
230
231        def utcoffset(self, dt):
232            return self.offset
233
234        def tzname(self, dt):
235            return self.name
236
237        def dst(self, dt):
238            return None
239
240    timezone.utc = timezone(timedelta(0), 'UTC')
241
242    _has_timezone = False
243else:
244    _has_timezone = True
245
246# time zones used in Postgres timestamptz output
247_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
248    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
249    UCT='+0000', UTC='+0000', WET='+0000')
250
251
252def _timezone_as_offset(tz):
253    if tz.startswith(('+', '-')):
254        if len(tz) < 5:
255            return tz + '00'
256        return tz.replace(':', '')
257    return _timezones.get(tz, '+0000')
258
259
260def _get_timezone(tz):
261    tz = _timezone_as_offset(tz)
262    minutes = 60 * int(tz[1:3]) + int(tz[3:5])
263    if tz[0] == '-':
264        minutes = -minutes
265    return timezone(timedelta(minutes=minutes), tz)
266
267
268def decimal_type(decimal_type=None):
269    """Get or set global type to be used for decimal values.
270
271    Note that connections cache cast functions. To be sure a global change
272    is picked up by a running connection, call con.type_cache.reset_typecast().
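
    For example, to use float instead of Decimal for numeric values:

        pgdb.decimal_type(float)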
273    """
274    global Decimal
275    if decimal_type is not None:
276        Decimal = decimal_type
277        set_typecast('numeric', decimal_type)
278    return Decimal
279
280
281def cast_bool(value):
282    """Cast boolean value in database format to bool."""
283    if value:
284        return value[0] in ('t', 'T')
285
286
287def cast_money(value):
288    """Cast money value in database format to Decimal."""
289    if value:
290        value = value.replace('(', '-')
291        return Decimal(''.join(c for c in value if c.isdigit() or c in '.-'))
292
293
294def cast_int2vector(value):
295    """Cast an int2vector value."""
296    return [int(v) for v in value.split()]
297
298
299def cast_date(value, connection):
300    """Cast a date value."""
301    # The output format depends on the server setting DateStyle.  The default
302    # setting ISO and the setting for German are actually unambiguous.  The
303    # order of days and months in the other two settings is however ambiguous,
304    # so at least here we need to consult the setting to properly parse values.
305    if value == '-infinity':
306        return date.min
307    if value == 'infinity':
308        return date.max
309    value = value.split()
310    if value[-1] == 'BC':
311        return date.min
312    value = value[0]
313    if len(value) > 10:
314        return date.max
315    fmt = connection.date_format()
316    return datetime.strptime(value, fmt).date()
317
318
319def cast_time(value):
320    """Cast a time value."""
321    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
322    return datetime.strptime(value, fmt).time()
323
324
325_re_timezone = regex('(.*)([+-].*)')
326
327
328def cast_timetz(value):
329    """Cast a timetz value."""
330    tz = _re_timezone.match(value)
331    if tz:
332        value, tz = tz.groups()
333    else:
334        tz = '+0000'
335    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
336    if _has_timezone:
337        value += _timezone_as_offset(tz)
338        fmt += '%z'
339        return datetime.strptime(value, fmt).timetz()
340    return datetime.strptime(value, fmt).timetz().replace(
341        tzinfo=_get_timezone(tz))
342
343
344def cast_timestamp(value, connection):
345    """Cast a timestamp value."""
346    if value == '-infinity':
347        return datetime.min
348    if value == 'infinity':
349        return datetime.max
350    value = value.split()
351    if value[-1] == 'BC':
352        return datetime.min
353    fmt = connection.date_format()
354    if fmt.endswith('-%Y') and len(value) > 2:
355        value = value[1:5]
356        if len(value[3]) > 4:
357            return datetime.max
358        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
359            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
360    else:
361        if len(value[0]) > 10:
362            return datetime.max
363        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
364    return datetime.strptime(' '.join(value), ' '.join(fmt))
365
366
367def cast_timestamptz(value, connection):
368    """Cast a timestamptz value."""
369    if value == '-infinity':
370        return datetime.min
371    if value == 'infinity':
372        return datetime.max
373    value = value.split()
374    if value[-1] == 'BC':
375        return datetime.min
376    fmt = connection.date_format()
377    if fmt.endswith('-%Y') and len(value) > 2:
378        value = value[1:]
379        if len(value[3]) > 4:
380            return datetime.max
381        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
382            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
383        value, tz = value[:-1], value[-1]
384    else:
385        if fmt.startswith('%Y-'):
386            tz = _re_timezone.match(value[1])
387            if tz:
388                value[1], tz = tz.groups()
389            else:
390                tz = '+0000'
391        else:
392            value, tz = value[:-1], value[-1]
393        if len(value[0]) > 10:
394            return datetime.max
395        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
396    if _has_timezone:
397        value.append(_timezone_as_offset(tz))
398        fmt.append('%z')
399        return datetime.strptime(' '.join(value), ' '.join(fmt))
400    return datetime.strptime(' '.join(value), ' '.join(fmt)).replace(
401        tzinfo=_get_timezone(tz))
402
403
404_re_interval_sql_standard = regex(
405    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
406    '(?:([+-]?[0-9]+)(?!:) ?)?'
407    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
408
409_re_interval_postgres = regex(
410    '(?:([+-]?[0-9]+) ?years? ?)?'
411    '(?:([+-]?[0-9]+) ?mons? ?)?'
412    '(?:([+-]?[0-9]+) ?days? ?)?'
413    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
414
415_re_interval_postgres_verbose = regex(
416    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
417    '(?:([+-]?[0-9]+) ?mons? ?)?'
418    '(?:([+-]?[0-9]+) ?days? ?)?'
419    '(?:([+-]?[0-9]+) ?hours? ?)?'
420    '(?:([+-]?[0-9]+) ?mins? ?)?'
421    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')
422
423_re_interval_iso_8601 = regex(
424    'P(?:([+-]?[0-9]+)Y)?'
425    '(?:([+-]?[0-9]+)M)?'
426    '(?:([+-]?[0-9]+)D)?'
427    '(?:T(?:([+-]?[0-9]+)H)?'
428    '(?:([+-]?[0-9]+)M)?'
429    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')
430
431
432def cast_interval(value):
433    """Cast an interval value."""
434    # The output format depends on the server setting IntervalStyle, but it's
435    # not necessary to consult this setting to parse it.  It's faster to just
436    # check all possible formats, and there is no ambiguity here.
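    # For example, both '1 day 02:00:00' and 'P1DT2H' are parsed to
    # timedelta(days=1, hours=2).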
437    m = _re_interval_iso_8601.match(value)
438    if m:
439        m = [d or '0' for d in m.groups()]
440        secs_ago = m.pop(5) == '-'
441        m = [int(d) for d in m]
442        years, mons, days, hours, mins, secs, usecs = m
443        if secs_ago:
444            secs = -secs
445            usecs = -usecs
446    else:
447        m = _re_interval_postgres_verbose.match(value)
448        if m:
449            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
450            secs_ago = m.pop(5) == '-'
451            m = [-int(d) for d in m] if ago else [int(d) for d in m]
452            years, mons, days, hours, mins, secs, usecs = m
453            if secs_ago:
454                secs = -secs
455                usecs = -usecs
456        else:
457            m = _re_interval_postgres.match(value)
458            if m and any(m.groups()):
459                m = [d or '0' for d in m.groups()]
460                hours_ago = m.pop(3) == '-'
461                m = [int(d) for d in m]
462                years, mons, days, hours, mins, secs, usecs = m
463                if hours_ago:
464                    hours = -hours
465                    mins = -mins
466                    secs = -secs
467                    usecs = -usecs
468            else:
469                m = _re_interval_sql_standard.match(value)
470                if m and any(m.groups()):
471                    m = [d or '0' for d in m.groups()]
472                    years_ago = m.pop(0) == '-'
473                    hours_ago = m.pop(3) == '-'
474                    m = [int(d) for d in m]
475                    years, mons, days, hours, mins, secs, usecs = m
476                    if years_ago:
477                        years = -years
478                        mons = -mons
479                    if hours_ago:
480                        hours = -hours
481                        mins = -mins
482                        secs = -secs
483                        usecs = -usecs
484                else:
485                    raise ValueError('Cannot parse interval: %s' % value)
486    days += 365 * years + 30 * mons
487    return timedelta(days=days, hours=hours, minutes=mins,
488        seconds=secs, microseconds=usecs)
489
490
491class Typecasts(dict):
492    """Dictionary mapping database types to typecast functions.
493
494    The cast functions get passed the string representation of a value in
495    the database which they need to convert to a Python object.  The
496    passed string will never be None since NULL values are already
497    handled before the cast function is called.
498    """
499
500    # the default cast functions
501    # (str functions are ignored but have been added for faster access)
502    defaults = {'char': str, 'bpchar': str, 'name': str,
503        'text': str, 'varchar': str,
504        'bool': cast_bool, 'bytea': unescape_bytea,
505        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
506        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
507        'float4': float, 'float8': float,
508        'numeric': Decimal, 'money': cast_money,
509        'date': cast_date, 'interval': cast_interval,
510        'time': cast_time, 'timetz': cast_timetz,
511        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
512        'int2vector': cast_int2vector, 'uuid': Uuid,
513        'anyarray': cast_array, 'record': cast_record}
514
515    connection = None  # will be set in local connection specific instances
516
517    def __missing__(self, typ):
518        """Create a cast function if it is not cached.
519
520        Note that this class never raises a KeyError,
521        but returns None when no special cast function exists.
522        """
523        if not isinstance(typ, str):
524            raise TypeError('Invalid type: %s' % typ)
525        cast = self.defaults.get(typ)
526        if cast:
527            # store default for faster access
528            cast = self._add_connection(cast)
529            self[typ] = cast
530        elif typ.startswith('_'):
531            # create array cast
532            base_cast = self[typ[1:]]
533            cast = self.create_array_cast(base_cast)
534            if base_cast:
535                # store only if base type exists
536                self[typ] = cast
537        return cast
538
539    @staticmethod
540    def _needs_connection(func):
541        """Check if a typecast function needs a connection argument."""
542        try:
543            args = get_args(func)
544        except (TypeError, ValueError):
545            return False
546        else:
547            return 'connection' in args[1:]
548
549    def _add_connection(self, cast):
550        """Add a connection argument to the typecast function if necessary."""
551        if not self.connection or not self._needs_connection(cast):
552            return cast
553        return partial(cast, connection=self.connection)
554
555    def get(self, typ, default=None):
556        """Get the typecast function for the given database type."""
557        return self[typ] or default
558
559    def set(self, typ, cast):
560        """Set a typecast function for the specified database type(s)."""
561        if isinstance(typ, basestring):
562            typ = [typ]
563        if cast is None:
564            for t in typ:
565                self.pop(t, None)
566                self.pop('_%s' % t, None)
567        else:
568            if not callable(cast):
569                raise TypeError("Cast parameter must be callable")
570            for t in typ:
571                self[t] = self._add_connection(cast)
572                self.pop('_%s' % t, None)
573
574    def reset(self, typ=None):
575        """Reset the typecasts for the specified type(s) to their defaults.
576
577        When no type is specified, all typecasts will be reset.
578        """
579        defaults = self.defaults
580        if typ is None:
581            self.clear()
582            self.update(defaults)
583        else:
584            if isinstance(typ, basestring):
585                typ = [typ]
586            for t in typ:
587                cast = defaults.get(t)
588                if cast:
589                    self[t] = self._add_connection(cast)
590                    t = '_%s' % t
591                    cast = defaults.get(t)
592                    if cast:
593                        self[t] = self._add_connection(cast)
594                    else:
595                        self.pop(t, None)
596                else:
597                    self.pop(t, None)
598                    self.pop('_%s' % t, None)
599
600    def create_array_cast(self, basecast):
601        """Create an array typecast for the given base cast."""
602        cast_array = self['anyarray']
603        def cast(v):
604            return cast_array(v, basecast)
605        return cast
606
607    def create_record_cast(self, name, fields, casts):
608        """Create a named record typecast for the given fields and casts."""
609        cast_record = self['record']
610        record = namedtuple(name, fields)
611        def cast(v):
612            return record(*cast_record(v, casts))
613        return cast
614
615
616_typecasts = Typecasts()  # this is the global typecast dictionary
617
618
619def get_typecast(typ):
620    """Get the global typecast function for the given database type(s)."""
621    return _typecasts.get(typ)
622
623
624def set_typecast(typ, cast):
625    """Set a global typecast function for the given database type(s).
626
627    Note that connections cache cast functions. To be sure a global change
628    is picked up by a running connection, call con.type_cache.reset_typecast().
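
    A sketch with a hypothetical cast function for the point type:

        def cast_point(value):
            x, y = value[1:-1].split(',')
            return float(x), float(y)

        pgdb.set_typecast('point', cast_point)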
629    """
630    _typecasts.set(typ, cast)
631
632
633def reset_typecast(typ=None):
634    """Reset the global typecasts for the given type(s) to their default.
635
636    When no type is specified, all typecasts will be reset.
637
638    Note that connections cache cast functions. To be sure a global change
639    is picked up by a running connection, call con.type_cache.reset_typecast().
640    """
641    _typecasts.reset(typ)
642
643
644class LocalTypecasts(Typecasts):
645    """Map typecasts, including local composite types, to cast functions."""
646
647    defaults = _typecasts
648
649    connection = None  # will be set in a connection specific instance
650
651    def __missing__(self, typ):
652        """Create a cast function if it is not cached."""
653        if typ.startswith('_'):
654            base_cast = self[typ[1:]]
655            cast = self.create_array_cast(base_cast)
656            if base_cast:
657                self[typ] = cast
658        else:
659            cast = self.defaults.get(typ)
660            if cast:
661                cast = self._add_connection(cast)
662                self[typ] = cast
663            else:
664                fields = self.get_fields(typ)
665                if fields:
666                    casts = [self[field.type] for field in fields]
667                    fields = [field.name for field in fields]
668                    cast = self.create_record_cast(typ, fields, casts)
669                    self[typ] = cast
670        return cast
671
672    def get_fields(self, typ):
673        """Return the fields for the given record type.
674
675        This method will be replaced with a method that looks up the fields
676        using the type cache of the connection.
677        """
678        return []
679
680
681class TypeCode(str):
682    """Class representing the type_code used by the DB-API 2.0.
683
684    TypeCode objects are strings equal to the PostgreSQL type name,
685    but carry some additional information.
686    """
687
688    @classmethod
689    def create(cls, oid, name, len, type, category, delim, relid):
690        """Create a type code for a PostgreSQL data type."""
691        self = cls(name)
692        self.oid = oid
693        self.len = len
694        self.type = type
695        self.category = category
696        self.delim = delim
697        self.relid = relid
698        return self
699
700FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
701
702
703class TypeCache(dict):
704    """Cache for database types.
705
706    This cache maps type OIDs and names to TypeCode strings containing
707    important information on the associated database type.
708    """
709
710    def __init__(self, cnx):
711        """Initialize type cache for connection."""
712        super(TypeCache, self).__init__()
713        self._escape_string = cnx.escape_string
714        self._src = cnx.source()
715        self._typecasts = LocalTypecasts()
716        self._typecasts.get_fields = self.get_fields
717        self._typecasts.connection = cnx
718        if cnx.server_version < 80400:
719            # older remote databases (not officially supported)
720            self._query_pg_type = ("SELECT oid, typname,"
721                " typlen, typtype, null as typcategory, typdelim, typrelid"
722                " FROM pg_type WHERE oid=%s")
723        else:
724            self._query_pg_type = ("SELECT oid, typname,"
725                " typlen, typtype, typcategory, typdelim, typrelid"
726                " FROM pg_type WHERE oid=%s")
727
728    def __missing__(self, key):
729        """Get the type info from the database if it is not cached."""
730        if isinstance(key, int):
731            oid = key
732        else:
733            if '.' not in key and '"' not in key:
734                key = '"%s"' % (key,)
735            oid = "'%s'::regtype" % (self._escape_string(key),)
736        try:
737            self._src.execute(self._query_pg_type % (oid,))
738        except ProgrammingError:
739            res = None
740        else:
741            res = self._src.fetch(1)
742        if not res:
743            raise KeyError('Type %s could not be found' % (key,))
744        res = res[0]
745        type_code = TypeCode.create(int(res[0]), res[1],
746            int(res[2]), res[3], res[4], res[5], int(res[6]))
747        self[type_code.oid] = self[str(type_code)] = type_code
748        return type_code
749
750    def get(self, key, default=None):
751        """Get the type even if it is not cached."""
752        try:
753            return self[key]
754        except KeyError:
755            return default
756
757    def get_fields(self, typ):
758        """Get the names and types of the fields of composite types."""
759        if not isinstance(typ, TypeCode):
760            typ = self.get(typ)
761            if not typ:
762                return None
763        if not typ.relid:
764            return None  # this type is not composite
765        self._src.execute("SELECT attname, atttypid"
766            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
767            " AND NOT attisdropped ORDER BY attnum" % (typ.relid,))
768        return [FieldInfo(name, self.get(int(oid)))
769            for name, oid in self._src.fetch(-1)]
770
771    def get_typecast(self, typ):
772        """Get the typecast function for the given database type."""
773        return self._typecasts.get(typ)
774
775    def set_typecast(self, typ, cast):
776        """Set a typecast function for the specified database type(s)."""
777        self._typecasts.set(typ, cast)
778
779    def reset_typecast(self, typ=None):
780        """Reset the typecast function for the specified database type(s)."""
781        self._typecasts.reset(typ)
782
783    def typecast(self, value, typ):
784        """Cast the given value according to the given database type."""
785        if value is None:
786            # for NULL values, no typecast is necessary
787            return None
788        cast = self.get_typecast(typ)
789        if not cast or cast is str:
790            # no typecast is necessary
791            return value
792        return cast(value)
793
794
795class _quotedict(dict):
796    """Dictionary with auto quoting of its items.
797
798    The quote attribute must be set to the desired quote function.
799    """
800
801    def __getitem__(self, key):
802        return self.quote(super(_quotedict, self).__getitem__(key))
803
804
805### Error Messages
806
807def _db_error(msg, cls=DatabaseError):
808    """Return DatabaseError with empty sqlstate attribute."""
809    error = cls(msg)
810    error.sqlstate = None
811    return error
812
813
814def _op_error(msg):
815    """Return OperationalError."""
816    return _db_error(msg, OperationalError)
817
818
819### Row Tuples
820
821_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$')
822
823# The result rows for database operations are returned as named tuples
824# by default. Since creating namedtuple classes is a somewhat expensive
825# operation, we cache up to 1024 of these classes by default.
826
827@lru_cache(maxsize=1024)
828def _row_factory(names):
829    """Get a namedtuple factory for row results with the given names."""
830    try:
831        try:
832            return namedtuple('Row', names, rename=True)._make
833        except TypeError:  # Python 2.6 and 3.0 do not support rename
834            names = [v if _re_fieldname.match(v) and not iskeyword(v)
835                        else 'column_%d' % (n,)
836                     for n, v in enumerate(names)]
837            return namedtuple('Row', names)._make
838    except ValueError:  # there is still a problem with the field names
839        names = ['column_%d' % (n,) for n in range(len(names))]
840        return namedtuple('Row', names)._make
841
842
843def set_row_factory_size(maxsize):
844    """Change the size of the namedtuple factory cache.
845
846    If maxsize is set to None, the cache can grow without bound.
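
    For example:

        pgdb.set_row_factory_size(4096)  # cache more row classes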
847    """
848    global _row_factory
849    _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__)
850
851
852### Cursor Object
853
854class Cursor(object):
855    """Cursor object."""
856
857    def __init__(self, dbcnx):
858        """Create a cursor object for the database connection."""
859        self.connection = self._dbcnx = dbcnx
860        self._cnx = dbcnx._cnx
861        self.type_cache = dbcnx.type_cache
862        self._src = self._cnx.source()
863        # the official attribute for describing the result columns
864        self._description = None
865        if self.row_factory is Cursor.row_factory:
866            # the row factory needs to be determined dynamically
867            self.row_factory = None
868        else:
869            self.build_row_factory = None
870        self.rowcount = -1
871        self.arraysize = 1
872        self.lastrowid = None
873
874    def __iter__(self):
875        """Make cursor compatible to the iteration protocol."""
876        return self
877
878    def __enter__(self):
879        """Enter the runtime context for the cursor object."""
880        return self
881
882    def __exit__(self, et, ev, tb):
883        """Exit the runtime context for the cursor object."""
884        self.close()
885
886    def _quote(self, value):
887        """Quote value depending on its type."""
888        if value is None:
889            return 'NULL'
890        if isinstance(value, (Hstore, Json)):
891            value = str(value)
892        if isinstance(value, basestring):
893            if isinstance(value, Binary):
894                value = self._cnx.escape_bytea(value)
895                if bytes is not str:  # Python >= 3.0
896                    value = value.decode('ascii')
897            else:
898                value = self._cnx.escape_string(value)
899            return "'%s'" % (value,)
900        if isinstance(value, float):
901            if isinf(value):
902                return "'-Infinity'" if value < 0 else "'Infinity'"
903            if isnan(value):
904                return "'NaN'"
905            return value
906        if isinstance(value, (int, long, Decimal, Literal)):
907            return value
908        if isinstance(value, datetime):
909            if value.tzinfo:
910                return "'%s'::timestamptz" % (value,)
911            return "'%s'::timestamp" % (value,)
912        if isinstance(value, date):
913            return "'%s'::date" % (value,)
914        if isinstance(value, time):
915            if value.tzinfo:
916                return "'%s'::timetz" % (value,)
917            return "'%s'::time" % value
918        if isinstance(value, timedelta):
919            return "'%s'::interval" % (value,)
920        if isinstance(value, Uuid):
921            return "'%s'::uuid" % (value,)
922        if isinstance(value, list):
923            # Quote value as an ARRAY constructor. This is better than using
924            # an array literal because it carries the information that this is
925            # an array and not a string.  One issue with this syntax is that
926            # you need to add an explicit typecast when passing empty arrays.
927            # The ARRAY keyword is actually only necessary at the top level.
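            # For example, [1, 2, 3] is rendered as ARRAY[1,2,3].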
928            if not value:  # exception for empty array
929                return "'{}'"
930            q = self._quote
931            try:
932                return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),)
933            except UnicodeEncodeError:  # Python 2 with non-ascii values
934                return u'ARRAY[%s]' % (','.join(unicode(q(v)) for v in value),)
935        if isinstance(value, tuple):
936            # Quote as a ROW constructor.  This is better than using a record
937            # literal because it carries the information that this is a record
938            # and not a string.  We don't use the keyword ROW in order to make
939            # this usable with the IN syntax as well.  It is only necessary
940            # when the record has a single column, which is not really useful.
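            # For example, (1, 'abc') is rendered as (1,'abc').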
941            q = self._quote
942            try:
943                return '(%s)' % (','.join(str(q(v)) for v in value),)
944            except UnicodeEncodeError:  # Python 2 with non-ascii values
945                return u'(%s)' % (','.join(unicode(q(v)) for v in value),)
946        try:
947            value = value.__pg_repr__()
948        except AttributeError:
949            raise InterfaceError(
950                'Do not know how to adapt type %s' % (type(value),))
951        if isinstance(value, (tuple, list)):
952            value = self._quote(value)
953        return value
954
955    def _quoteparams(self, string, parameters):
956        """Quote parameters.
957
958        This function works for both mappings and sequences.
959
960        The function should be used even when there are no parameters,
961        so that we have a consistent behavior regarding percent signs.
962        """
963        if not parameters:
964            try:
965                return string % ()  # unescape literal quotes if possible
966            except (TypeError, ValueError):
967                return string  # silently accept unescaped quotes
968        if isinstance(parameters, dict):
969            parameters = _quotedict(parameters)
970            parameters.quote = self._quote
971        else:
972            parameters = tuple(map(self._quote, parameters))
973        return string % parameters
974
975    def _make_description(self, info):
976        """Make the description tuple for the given field info."""
977        name, typ, size, mod = info[1:]
978        type_code = self.type_cache[typ]
979        if mod > 0:
980            mod -= 4
981        if type_code == 'numeric':
982            precision, scale = mod >> 16, mod & 0xffff
983            size = precision
984        else:
985            if not size:
986                size = type_code.len
987            if size == -1:
988                size = mod
989            precision = scale = None
990        return CursorDescription(name, type_code,
991            None, size, precision, scale, None)
992
993    @property
994    def description(self):
995        """Read-only attribute describing the result columns."""
996        descr = self._description
997        if self._description is True:
998            make = self._make_description
999            descr = [make(info) for info in self._src.listinfo()]
1000            self._description = descr
1001        return descr
1002
1003    @property
1004    def colnames(self):
1005        """Unofficial convenience method for getting the column names."""
1006        return [d[0] for d in self.description]
1007
1008    @property
1009    def coltypes(self):
1010        """Unofficial convenience method for getting the column types."""
1011        return [d[1] for d in self.description]
1012
1013    def close(self):
1014        """Close the cursor object."""
1015        self._src.close()
1016
1017    def execute(self, operation, parameters=None):
1018        """Prepare and execute a database operation (query or command)."""
1019        # The parameters may also be specified as list of tuples to e.g.
1020        # insert multiple rows in a single operation, but this kind of
1021        # usage is deprecated.  We make several plausibility checks because
1022        # tuples can also be passed with the meaning of ROW constructors.
1023        if (parameters and isinstance(parameters, list)
1024                and len(parameters) > 1
1025                and all(isinstance(p, tuple) for p in parameters)
1026                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
1027            return self.executemany(operation, parameters)
1028        else:
1029            # not a list of tuples
1030            return self.executemany(operation, [parameters])
1031
1032    def executemany(self, operation, seq_of_parameters):
1033        """Prepare operation and execute it against a parameter sequence."""
1034        if not seq_of_parameters:
1035            # don't do anything without parameters
1036            return
1037        self._description = None
1038        self.rowcount = -1
1039        # first try to execute all queries
1040        rowcount = 0
1041        sql = "BEGIN"
1042        try:
1043            if not self._dbcnx._tnx:
1044                try:
1045                    self._src.execute(sql)
1046                except DatabaseError:
1047                    raise  # database provides error message
1048                except Exception:
1049                    raise _op_error("Can't start transaction")
1050                self._dbcnx._tnx = True
1051            for parameters in seq_of_parameters:
1052                sql = operation
1053                sql = self._quoteparams(sql, parameters)
1054                rows = self._src.execute(sql)
1055                if rows:  # true if not DML
1056                    rowcount += rows
1057                else:
1058                    self.rowcount = -1
1059        except DatabaseError:
1060            raise  # database provides error message
1061        except Error as err:
1062            raise _db_error(
1063                "Error in '%s': '%s' " % (sql, err), InterfaceError)
1064        except Exception as err:
1065            raise _op_error("Internal error in '%s': %s" % (sql, err))
1066        # then initialize result raw count and description
1067        if self._src.resulttype == RESULT_DQL:
1068            self._description = True  # fetch on demand
1069            self.rowcount = self._src.ntuples
1070            self.lastrowid = None
1071            if self.build_row_factory:
1072                self.row_factory = self.build_row_factory()
1073        else:
1074            self.rowcount = rowcount
1075            self.lastrowid = self._src.oidstatus()
1076        # return the cursor object, so you can write statements such as
1077        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
1078        return self
1079
1080    def fetchone(self):
1081        """Fetch the next row of a query result set."""
1082        res = self.fetchmany(1, False)
1083        try:
1084            return res[0]
1085        except IndexError:
1086            return None
1087
1088    def fetchall(self):
1089        """Fetch all (remaining) rows of a query result."""
1090        return self.fetchmany(-1, False)
1091
1092    def fetchmany(self, size=None, keep=False):
1093        """Fetch the next set of rows of a query result.
1094
1095        The number of rows to fetch per call is specified by the
1096        size parameter. If it is not given, the cursor's arraysize
1097        determines the number of rows to be fetched. If you set
1098        the keep parameter to true, this is kept as new arraysize.
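
        For example, cursor.fetchmany(100, keep=True) fetches the next
        100 rows and also sets the arraysize to 100.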
1099        """
1100        if size is None:
1101            size = self.arraysize
1102        if keep:
1103            self.arraysize = size
1104        try:
1105            result = self._src.fetch(size)
1106        except DatabaseError:
1107            raise
1108        except Error as err:
1109            raise _db_error(str(err))
1110        typecast = self.type_cache.typecast
1111        return [self.row_factory([typecast(value, typ)
1112            for typ, value in zip(self.coltypes, row)]) for row in result]
1113
1114    def callproc(self, procname, parameters=None):
1115        """Call a stored database procedure with the given name.
1116
1117        The sequence of parameters must contain one entry for each input
1118        argument that the procedure expects. The result of the call is the
1119        same as this input sequence; replacement of output and input/output
1120        parameters in the return value is currently not supported.
1121
1122        The procedure may also provide a result set as output. These can be
1123        requested through the standard fetch methods of the cursor.
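
        A small sketch (assuming a database function add_numbers(int, int)):

            cursor.callproc('add_numbers', (2, 3))
            result = cursor.fetchone()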
1124        """
1125        n = parameters and len(parameters) or 0
1126        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
1127        self.execute(query, parameters)
1128        return parameters
1129
1130    def copy_from(self, stream, table,
1131            format=None, sep=None, null=None, size=None, columns=None):
1132        """Copy data from an input stream to the specified table.
1133
1134        The input stream can be a file-like object with a read() method or
1135        it can also be an iterable returning a row or multiple rows of input
1136        on each iteration.
1137
1138        The format must be text, csv or binary. The sep option sets the
1139        column separator (delimiter) used in the non binary formats.
1140        The null option sets the textual representation of NULL in the input.
1141
1142        The size option sets the size of the buffer used when reading data
1143        from file-like objects.
1144
1145        The copy operation can be restricted to a subset of columns. If no
1146        columns are specified, all of them will be copied.
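
        A small sketch (assuming a table "fruits" and a matching CSV file):

            with open('fruits.csv') as f:
                cursor.copy_from(f, 'fruits', format='csv')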
1147        """
1148        binary_format = format == 'binary'
1149        try:
1150            read = stream.read
1151        except AttributeError:
1152            if size:
1153                raise ValueError("Size must only be set for file-like objects")
1154            if binary_format:
1155                input_type = bytes
1156                type_name = 'byte strings'
1157            else:
1158                input_type = basestring
1159                type_name = 'strings'
1160
1161            if isinstance(stream, basestring):
1162                if not isinstance(stream, input_type):
1163                    raise ValueError("The input must be %s" % (type_name,))
1164                if not binary_format:
1165                    if isinstance(stream, str):
1166                        if not stream.endswith('\n'):
1167                            stream += '\n'
1168                    else:
1169                        if not stream.endswith(b'\n'):
1170                            stream += b'\n'
1171
1172                def chunks():
1173                    yield stream
1174
1175            elif isinstance(stream, Iterable):
1176
1177                def chunks():
1178                    for chunk in stream:
1179                        if not isinstance(chunk, input_type):
1180                            raise ValueError(
1181                                "Input stream must consist of %s"
1182                                % (type_name,))
1183                        if isinstance(chunk, str):
1184                            if not chunk.endswith('\n'):
1185                                chunk += '\n'
1186                        else:
1187                            if not chunk.endswith(b'\n'):
1188                                chunk += b'\n'
1189                        yield chunk
1190
1191            else:
1192                raise TypeError("Need an input stream to copy from")
1193        else:
1194            if size is None:
1195                size = 8192
1196            elif not isinstance(size, int):
1197                raise TypeError("The size option must be an integer")
1198            if size > 0:
1199
1200                def chunks():
1201                    while True:
1202                        buffer = read(size)
1203                        yield buffer
1204                        if not buffer or len(buffer) < size:
1205                            break
1206
1207            else:
1208
1209                def chunks():
1210                    yield read()
1211
1212        if not table or not isinstance(table, basestring):
1213            raise TypeError("Need a table to copy to")
1214        if table.lower().startswith('select'):
1215            raise ValueError("Must specify a table, not a query")
1216        else:
1217            table = '"%s"' % (table,)
1218        operation = ['copy %s' % (table,)]
1219        options = []
1220        params = []
1221        if format is not None:
1222            if not isinstance(format, basestring):
1223                raise TypeError("The frmat option must be be a string")
1224            if format not in ('text', 'csv', 'binary'):
1225                raise ValueError("Invalid format")
1226            options.append('format %s' % (format,))
1227        if sep is not None:
1228            if not isinstance(sep, basestring):
1229                raise TypeError("The sep option must be a string")
1230            if format == 'binary':
1231                raise ValueError(
1232                    "The sep option is not allowed with binary format")
1233            if len(sep) != 1:
1234                raise ValueError(
1235                    "The sep option must be a single one-byte character")
1236            options.append('delimiter %s')
1237            params.append(sep)
1238        if null is not None:
1239            if not isinstance(null, basestring):
1240                raise TypeError("The null option must be a string")
1241            options.append('null %s')
1242            params.append(null)
1243        if columns:
1244            if not isinstance(columns, basestring):
1245                columns = ','.join('"%s"' % (col,) for col in columns)
1246            operation.append('(%s)' % (columns,))
1247        operation.append("from stdin")
1248        if options:
1249            operation.append('(%s)' % (','.join(options),))
1250        operation = ' '.join(operation)
1251
1252        putdata = self._src.putdata
1253        self.execute(operation, params)
1254
1255        try:
1256            for chunk in chunks():
1257                putdata(chunk)
1258        except BaseException as error:
1259            self.rowcount = -1
1260            # the following call will re-raise the error
1261            putdata(error)
1262        else:
1263            self.rowcount = putdata(None)
1264
1265        # return the cursor object, so you can chain operations
1266        return self
1267
1268    def copy_to(self, stream, table,
1269            format=None, sep=None, null=None, decode=None, columns=None):
1270        """Copy data from the specified table to an output stream.
1271
1272        The output stream can be a file-like object with a write() method or
1273        it can also be None, in which case the method will return a generator
1274        yielding a row on each iteration.
1275
1276        Output will be returned as byte strings unless you set decode to true.
1277
1278        Note that you can also use a select query instead of the table name.
1279
1280        The format must be text, csv or binary. The sep option sets the
1281        column separator (delimiter) used in the non binary formats.
1282        The null option sets the textual representation of NULL in the output.
1283
1284        The copy operation can be restricted to a subset of columns. If no
1285        columns are specified, all of them will be copied.
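
        A small sketch (assuming a table "fruits"):

            with open('fruits.csv', 'w') as f:
                cursor.copy_to(f, 'fruits', format='csv')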
1286        """
1287        binary_format = format == 'binary'
1288        if stream is not None:
1289            try:
1290                write = stream.write
1291            except AttributeError:
1292                raise TypeError("Need an output stream to copy to")
1293        if not table or not isinstance(table, basestring):
1294            raise TypeError("Need a table to copy to")
1295        if table.lower().startswith('select'):
1296            if columns:
1297                raise ValueError("Columns must be specified in the query")
1298            table = '(%s)' % (table,)
1299        else:
1300            table = '"%s"' % (table,)
1301        operation = ['copy %s' % (table,)]
1302        options = []
1303        params = []
1304        if format is not None:
1305            if not isinstance(format, basestring):
1306                raise TypeError("The format option must be a string")
1307            if format not in ('text', 'csv', 'binary'):
1308                raise ValueError("Invalid format")
1309            options.append('format %s' % (format,))
1310        if sep is not None:
1311            if not isinstance(sep, basestring):
1312                raise TypeError("The sep option must be a string")
1313            if binary_format:
1314                raise ValueError(
1315                    "The sep option is not allowed with binary format")
1316            if len(sep) != 1:
1317                raise ValueError(
1318                    "The sep option must be a single one-byte character")
1319            options.append('delimiter %s')
1320            params.append(sep)
1321        if null is not None:
1322            if not isinstance(null, basestring):
1323                raise TypeError("The null option must be a string")
1324            options.append('null %s')
1325            params.append(null)
1326        if decode is None:
1327            if format == 'binary':
1328                decode = False
1329            else:
1330                decode = str is unicode
1331        else:
1332            if not isinstance(decode, (int, bool)):
1333                raise TypeError("The decode option must be a boolean")
1334            if decode and binary_format:
1335                raise ValueError(
1336                    "The decode option is not allowed with binary format")
1337        if columns:
1338            if not isinstance(columns, basestring):
1339                columns = ','.join('"%s"' % (col,) for col in columns)
1340            operation.append('(%s)' % (columns,))
1341
1342        operation.append("to stdout")
1343        if options:
1344            operation.append('(%s)' % (','.join(options),))
1345        operation = ' '.join(operation)
1346
1347        getdata = self._src.getdata
1348        self.execute(operation, params)
1349
1350        def copy():
1351            self.rowcount = 0
1352            while True:
1353                row = getdata(decode)
1354                if isinstance(row, int):
1355                    if self.rowcount != row:
1356                        self.rowcount = row
1357                    break
1358                self.rowcount += 1
1359                yield row
1360
1361        if stream is None:
1362            # no input stream, return the generator
1363            return copy()
1364
1366        # write the rows to the file-like output stream
1366        for row in copy():
1367            write(row)
1368
1369        # return the cursor object, so you can chain operations
1370        return self
1371
1372    def __next__(self):
1373        """Return the next row (support for the iteration protocol)."""
1374        res = self.fetchone()
1375        if res is None:
1376            raise StopIteration
1377        return res
1378
1379    # Note that the iteration protocol in Python 3 uses __next__() instead
1380    # of next(); we keep next() only for backward compatibility of pgdb.
1381    next = __next__
1382
1383    @staticmethod
1384    def nextset():
1385        """Not supported."""
1386        raise NotSupportedError("The nextset() method is not supported")
1387
1388    @staticmethod
1389    def setinputsizes(sizes):
1390        """Not supported."""
1391        pass  # unsupported, but silently passed
1392
1393    @staticmethod
1394    def setoutputsize(size, column=0):
1395        """Not supported."""
1396        pass  # unsupported, but silently passed
1397
1398    @staticmethod
1399    def row_factory(row):
1400        """Process rows before they are returned.
1401
1402        You can overwrite this statically with a custom row factory, or
1403        you can build a row factory dynamically with build_row_factory().
1404
1405        For example, you can create a Cursor class that returns rows as
1406        Python dictionaries like this:
1407
1408            class DictCursor(pgdb.Cursor):
1409
1410                def row_factory(self, row):
1411                    return {desc[0]: value
1412                        for desc, value in zip(self.description, row)}
1413
1414            cur = DictCursor(con)  # get one DictCursor instance or
1415            con.cursor_type = DictCursor  # always use DictCursor instances
1416        """
1417        raise NotImplementedError
1418
1419    def build_row_factory(self):
1420        """Build a row factory based on the current description.
1421
1422        This implementation builds a row factory for creating named tuples.
1423        You can overwrite this method if you want to dynamically create
1424        different row factories whenever the column description changes.
1425        """
1426        names = self.colnames
1427        if names:
1428            return _row_factory(tuple(names))
1429
1430
1431
1432CursorDescription = namedtuple('CursorDescription',
1433    ['name', 'type_code', 'display_size', 'internal_size',
1434     'precision', 'scale', 'null_ok'])
1435
1436
1437### Connection Objects
1438
1439class Connection(object):
1440    """Connection object."""
1441
1442    # expose the exceptions as attributes on the connection object
1443    Error = Error
1444    Warning = Warning
1445    InterfaceError = InterfaceError
1446    DatabaseError = DatabaseError
1447    InternalError = InternalError
1448    OperationalError = OperationalError
1449    ProgrammingError = ProgrammingError
1450    IntegrityError = IntegrityError
1451    DataError = DataError
1452    NotSupportedError = NotSupportedError
1453
1454    def __init__(self, cnx):
1455        """Create a database connection object."""
1456        self._cnx = cnx  # connection
1457        self._tnx = False  # transaction state
1458        self.type_cache = TypeCache(cnx)
1459        self.cursor_type = Cursor
1460        try:
1461            self._cnx.source()
1462        except Exception:
1463            raise _op_error("Invalid connection")
1464
1465    def __enter__(self):
1466        """Enter the runtime context for the connection object.
1467
1468        The runtime context can be used for running transactions.
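
        A small sketch (the table name is only an example):

            with con:
                con.cursor().execute("insert into fruits values (%(n)s)",
                    {'n': 'banana'})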
1469        """
1470        return self
1471
1472    def __exit__(self, et, ev, tb):
1473        """Exit the runtime context for the connection object.
1474
1475        This does not close the connection, but it ends a transaction.
1476        """
1477        if et is None and ev is None and tb is None:
1478            self.commit()
1479        else:
1480            self.rollback()
1481
1482    def close(self):
1483        """Close the connection object."""
1484        if self._cnx:
1485            if self._tnx:
1486                try:
1487                    self.rollback()
1488                except DatabaseError:
1489                    pass
1490            self._cnx.close()
1491            self._cnx = None
1492        else:
1493            raise _op_error("Connection has been closed")
1494
1495    @property
1496    def closed(self):
1497        """Check whether the connection has been closed or is broken."""
1498        try:
1499            return not self._cnx or self._cnx.status != 1
1500        except TypeError:
1501            return True
1502
1503    def commit(self):
1504        """Commit any pending transaction to the database."""
1505        if self._cnx:
1506            if self._tnx:
1507                self._tnx = False
1508                try:
1509                    self._cnx.source().execute("COMMIT")
1510                except DatabaseError:
1511                    raise
1512                except Exception:
1513                    raise _op_error("Can't commit")
1514        else:
1515            raise _op_error("Connection has been closed")
1516
1517    def rollback(self):
1518        """Roll back to the start of any pending transaction."""
1519        if self._cnx:
1520            if self._tnx:
1521                self._tnx = False
1522                try:
1523                    self._cnx.source().execute("ROLLBACK")
1524                except DatabaseError:
1525                    raise
1526                except Exception:
1527                    raise _op_error("Can't rollback")
1528        else:
1529            raise _op_error("Connection has been closed")
1530
1531    def cursor(self):
1532        """Return a new cursor object using the connection."""
1533        if self._cnx:
1534            try:
1535                return self.cursor_type(self)
1536            except Exception:
1537                raise _op_error("Invalid connection")
1538        else:
1539            raise _op_error("Connection has been closed")
1540
1541    if shortcutmethods:  # otherwise do not implement and document this
1542
1543        def execute(self, operation, params=None):
1544            """Shortcut method to run an operation on an implicit cursor."""
1545            cursor = self.cursor()
1546            cursor.execute(operation, params)
1547            return cursor
1548
1549        def executemany(self, operation, param_seq):
1550            """Shortcut method to run an operation against a sequence."""
1551            cursor = self.cursor()
1552            cursor.executemany(operation, param_seq)
1553            return cursor
1554
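# A hedged note: with shortcutmethods enabled, statements can be run
# directly on the connection and the implicit cursor is returned:
#
#     cur = con.execute("select version()")
#     print(cur.fetchone())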
1555
1556### Module Interface
1557
1558_connect = connect
1559
1560def connect(dsn=None,
1561        user=None, password=None,
1562        host=None, database=None, **kwargs):
1563    """Connect to a database."""
1564    # first get params from DSN
1565    dbport = -1
1566    dbhost = ""
1567    dbname = ""
1568    dbuser = ""
1569    dbpasswd = ""
1570    dbopt = ""
1571    try:
1572        params = dsn.split(":")
1573        dbhost = params[0]
1574        dbname = params[1]
1575        dbuser = params[2]
1576        dbpasswd = params[3]
1577        dbopt = params[4]
1578    except (AttributeError, IndexError, TypeError):
1579        pass
1580
1581    # override if necessary
1582    if user is not None:
1583        dbuser = user
1584    if password is not None:
1585        dbpasswd = password
1586    if database is not None:
1587        dbname = database
1588    if host is not None:
1589        try:
1590            params = host.split(":")
1591            dbhost = params[0]
1592            dbport = int(params[1])
1593        except (AttributeError, IndexError, TypeError, ValueError):
1594            pass
1595
1596    # empty host is localhost
1597    if dbhost == "":
1598        dbhost = None
1599    if dbuser == "":
1600        dbuser = None
1601
1602    # pass keyword arguments as connection info string
1603    if kwargs:
1604        kwargs = list(kwargs.items())
1605        if '=' in dbname:
1606            dbname = [dbname]
1607        else:
1608            kwargs.insert(0, ('dbname', dbname))
1609            dbname = []
1610        for kw, value in kwargs:
1611            value = str(value)
1612            if not value or ' ' in value:
1613                value = "'%s'" % (value.replace(
1614                    '\\', '\\\\').replace("'", "\\'"),)
1615            dbname.append('%s=%s' % (kw, value))
1616        dbname = ' '.join(dbname)
1617
1618    # open the connection
1619    cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd)
1620    return Connection(cnx)
1621
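# A hedged sketch of how keyword arguments are folded into the connection
# info string (parameter values below are illustrative):
#
#     con = connect(database='testdb', host='pgserver:5433',
#         sslmode='require')
#     # dbhost='pgserver', dbport=5433, and _connect() receives the info
#     # string "dbname=testdb sslmode=require" as its first argument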
1622
1623### Types Handling
1624
1625class Type(frozenset):
1626    """Type class for a couple of PostgreSQL data types.
1627
1628    PostgreSQL is object-oriented: types are dynamic.
1629    We must thus use type names as internal type codes.
1630    """
1631
1632    def __new__(cls, values):
1633        if isinstance(values, basestring):
1634            values = values.split()
1635        return super(Type, cls).__new__(cls, values)
1636
1637    def __eq__(self, other):
1638        if isinstance(other, basestring):
1639            if other.startswith('_'):
1640                other = other[1:]
1641            return other in self
1642        else:
1643            return super(Type, self).__eq__(other)
1644
1645    def __ne__(self, other):
1646        if isinstance(other, basestring):
1647            if other.startswith('_'):
1648                other = other[1:]
1649            return other not in self
1650        else:
1651            return super(Type, self).__ne__(other)
1652
1653
1654class ArrayType:
1655    """Type class for PostgreSQL array types."""
1656
1657    def __eq__(self, other):
1658        if isinstance(other, basestring):
1659            return other.startswith('_')
1660        else:
1661            return isinstance(other, ArrayType)
1662
1663    def __ne__(self, other):
1664        if isinstance(other, basestring):
1665            return not other.startswith('_')
1666        else:
1667            return not isinstance(other, ArrayType)
1668
1669
1670class RecordType:
1671    """Type class for PostgreSQL record types."""
1672
1673    def __eq__(self, other):
1674        if isinstance(other, TypeCode):
1675            return other.type == 'c'
1676        elif isinstance(other, basestring):
1677            return other == 'record'
1678        else:
1679            return isinstance(other, RecordType)
1680
1681    def __ne__(self, other):
1682        if isinstance(other, TypeCode):
1683            return other.type != 'c'
1684        elif isinstance(other, basestring):
1685            return other != 'record'
1686        else:
1687            return not isinstance(other, RecordType)
1688
1689
1690# Mandatory type objects defined by DB-API 2 specs:
1691
1692STRING = Type('char bpchar name text varchar')
1693BINARY = Type('bytea')
1694NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
1695DATETIME = Type('date time timetz timestamp timestamptz interval'
1696    ' abstime reltime')  # these are very old
1697ROWID = Type('oid')
1698
1699
1700# Additional type objects (more specific):
1701
1702BOOL = Type('bool')
1703SMALLINT = Type('int2')
1704INTEGER = Type('int2 int4 int8 serial')
1705LONG = Type('int8')
1706FLOAT = Type('float4 float8')
1707NUMERIC = Type('numeric')
1708MONEY = Type('money')
1709DATE = Type('date')
1710TIME = Type('time timetz')
1711TIMESTAMP = Type('timestamp timestamptz')
1712INTERVAL = Type('interval')
1713UUID = Type('uuid')
1714HSTORE = Type('hstore')
1715JSON = Type('json jsonb')
1716
1717# Type object for arrays (also equate to their base types):
1718
1719ARRAY = ArrayType()
1720
1721# Type object for records (encompassing all composite types):
1722
1723RECORD = RecordType()
1724
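# A hedged illustration of how these type objects compare against the
# type_code strings reported in cursor.description:
#
#     STRING == 'varchar'   # True -- 'varchar' is one of the STRING names
#     NUMBER == 'int4'      # True
#     NUMBER == '_int4'     # True -- a leading '_' (array marker) is ignored
#     ARRAY == '_int4'      # True -- any array type name matches ARRAY
#     RECORD == 'record'    # True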
1725
1726# Mandatory type helpers defined by DB-API 2 specs:
1727
1728def Date(year, month, day):
1729    """Construct an object holding a date value."""
1730    return date(year, month, day)
1731
1732
1733def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
1734    """Construct an object holding a time value."""
1735    return time(hour, minute, second, microsecond, tzinfo)
1736
1737
1738def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
1739        tzinfo=None):
1740    """Construct an object holding a time stamp value."""
1741    return datetime(year, month, day, hour, minute, second, microsecond, tzinfo)
1742
1743
1744def DateFromTicks(ticks):
1745    """Construct an object holding a date value from the given ticks value."""
1746    return Date(*localtime(ticks)[:3])
1747
1748
1749def TimeFromTicks(ticks):
1750    """Construct an object holding a time value from the given ticks value."""
1751    return Time(*localtime(ticks)[3:6])
1752
1753
1754def TimestampFromTicks(ticks):
1755    """Construct an object holding a time stamp from the given ticks value."""
1756    return Timestamp(*localtime(ticks)[:6])
1757
1758
1759class Binary(bytes):
1760    """Construct an object capable of holding a binary (long) string value."""
1761
1762
1763# Additional type helpers for PyGreSQL:
1764
1765def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
1766    """Construct an object holding a time interval value."""
1767    return timedelta(days, hours=hours, minutes=minutes, seconds=seconds,
1768        microseconds=microseconds)
1769
1770
1771Uuid = Uuid  # Construct an object holding a UUID value
1772
1773
1774class Hstore(dict):
1775    """Wrapper class for marking hstore values."""
1776
1777    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
1778    _re_escape = regex(r'(["\\])')
1779
1780    @classmethod
1781    def _quote(cls, s):
1782        if s is None:
1783            return 'NULL'
1784        if not s:
1785            return '""'
1786        quote = cls._re_quote.search(s)
1787        s = cls._re_escape.sub(r'\\\1', s)
1788        if quote:
1789            s = '"%s"' % (s,)
1790        return s
1791
1792    def __str__(self):
1793        q = self._quote
1794        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
1795
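# A hedged example of the hstore text form produced by Hstore.__str__
# (dict ordering may vary across Python versions):
#
#     str(Hstore({'color': 'red', 'size': None, 'note': 'a b'}))
#     # e.g. 'color=>red,size=>NULL,note=>"a b"' -- values that need it
#     # are double-quoted; embedded quotes and backslashes get escaped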
1796
1797class Json:
1798    """Construct a wrapper for holding an object serializable to JSON."""
1799
1800    def __init__(self, obj, encode=None):
1801        self.obj = obj
1802        self.encode = encode or jsonencode
1803
1804    def __str__(self):
1805        obj = self.obj
1806        if isinstance(obj, basestring):
1807            return obj
1808        return self.encode(obj)
1809
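# A hedged usage sketch: the Json wrapper serializes its payload with
# json.dumps unless a custom encoder is supplied:
#
#     param = Json({'id': 1, 'tags': ['a', 'b']})
#     str(param)   # e.g. '{"id": 1, "tags": ["a", "b"]}'
#     str(Json('{"already": "encoded"}'))   # strings pass through unchanged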
1810
1811class Literal:
1812    """Construct a wrapper for holding a literal SQL string."""
1813
1814    def __init__(self, sql):
1815        self.sql = sql
1816
1817    def __str__(self):
1818        return self.sql
1819
1820    __pg_repr__ = __str__
1821
1822# If run as script, print some information:
1823
1824if __name__ == '__main__':
1825    print('PyGreSQL version', version)
1826    print('')
1827    print(__doc__)