source: trunk/pgdb.py @ 973

Last change on this file since 973 was 969, checked in by cito, 4 months ago

Shebang should not be followed by a blank

It's a myth that it is needed because some old versions of Unix expect it.
However, some editors do not like a blank here.

  • Property svn:keywords set to Id
File size: 59.9 KB
1#!/usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 969 2019-04-19 13:35:23Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary) if they are
32    # passed. The binding syntax is the same as the % operator
33    # for dictionaries, and no quoting is done.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function, division
67
68from _pg import *
69
70__version__ = version
71
72from datetime import date, time, datetime, timedelta, tzinfo
73from time import localtime
74from decimal import Decimal
75from uuid import UUID as Uuid
76from math import isnan, isinf
77try:
78    from collections.abc import Iterable
79except ImportError:  # Python < 3.3
80    from collections import Iterable
81from collections import namedtuple
82from keyword import iskeyword
83from functools import partial
84from re import compile as regex
85from json import loads as jsondecode, dumps as jsonencode
86
87try:
88    long
89except NameError:  # Python >= 3.0
90    long = int
91
92try:
93    unicode
94except NameError:  # Python >= 3.0
95    unicode = str
96
97try:
98    basestring
99except NameError:  # Python >= 3.0
100    basestring = (str, bytes)
101
102try:
103    from functools import lru_cache
104except ImportError:  # Python < 3.2
105    from functools import update_wrapper
106    try:
107        from _thread import RLock
108    except ImportError:
109        class RLock:  # for builds without threads
110            def __enter__(self): pass
111
112            def __exit__(self, exctype, excinst, exctb): pass
113
114    def lru_cache(maxsize=128):
115        """Simplified functools.lru_cache decorator for one argument."""
116
117        def decorator(function):
118            sentinel = object()
119            cache = {}
120            get = cache.get
121            lock = RLock()
122            root = []
123            root_full = [root, False]
124            root[:] = [root, root, None, None]
125
126            if maxsize == 0:
127
128                def wrapper(arg):
129                    res = function(arg)
130                    return res
131
132            elif maxsize is None:
133
134                def wrapper(arg):
135                    res = get(arg, sentinel)
136                    if res is not sentinel:
137                        return res
138                    res = function(arg)
139                    cache[arg] = res
140                    return res
141
142            else:
143
144                def wrapper(arg):
145                    with lock:
146                        link = get(arg)
147                        if link is not None:
148                            root = root_full[0]
149                            prev, next, _arg, res = link
150                            prev[1] = next
151                            next[0] = prev
152                            last = root[0]
153                            last[1] = root[0] = link
154                            link[0] = last
155                            link[1] = root
156                            return res
157                    res = function(arg)
158                    with lock:
159                        root, full = root_full
160                        if arg in cache:
161                            pass
162                        elif full:
163                            oldroot = root
164                            oldroot[2] = arg
165                            oldroot[3] = res
166                            root = root_full[0] = oldroot[1]
167                            oldarg = root[2]
168                            oldres = root[3]  # keep reference
169                            root[2] = root[3] = None
170                            del cache[oldarg]
171                            cache[arg] = oldroot
172                        else:
173                            last = root[0]
174                            link = [last, root, arg, res]
175                            last[1] = root[0] = cache[arg] = link
176                            if len(cache) >= maxsize:
177                                root_full[1] = True
178                    return res
179
180            wrapper.__wrapped__ = function
181            return update_wrapper(wrapper, function)
182
183        return decorator
184
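# Illustrative sketch (not part of the original module): the simplified
# fallback above caches calls of a function taking a single hashable
# argument, mirroring functools.lru_cache for that limited case.
#
#     @lru_cache(maxsize=2)
#     def square(x):
#         return x * x
#
#     square(3)  # computed and stored in the cache
#     square(3)  # served from the cache without calling square() again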
185
186### Module Constants
187
188# compliant with DB API 2.0
189apilevel = '2.0'
190
191# module may be shared, but not connections
192threadsafety = 1
193
194# this module uses extended python format codes
195paramstyle = 'pyformat'
196
197# shortcut methods have been excluded from DB API 2 and
198# are not recommended by the DB SIG, but they can be handy
199shortcutmethods = 1
200
201
202### Internal Type Handling
203
204try:
205    from inspect import signature
206except ImportError:  # Python < 3.3
207    from inspect import getargspec
208
209    def get_args(func):
210        return getargspec(func).args
211else:
212
213    def get_args(func):
214        return list(signature(func).parameters)
215
216try:
217    from datetime import timezone
218except ImportError:  # Python < 3.2
219
220    class timezone(tzinfo):
221        """Simple timezone implementation."""
222
223        def __init__(self, offset, name=None):
224            self.offset = offset
225            if not name:
226                minutes = self.offset.days * 1440 + self.offset.seconds // 60
227                if minutes < 0:
228                    hours, minutes = divmod(-minutes, 60)
229                    hours = -hours
230                else:
231                    hours, minutes = divmod(minutes, 60)
232                name = 'UTC%+03d:%02d' % (hours, minutes)
233            self.name = name
234
235        def utcoffset(self, dt):
236            return self.offset
237
238        def tzname(self, dt):
239            return self.name
240
241        def dst(self, dt):
242            return None
243
244    timezone.utc = timezone(timedelta(0), 'UTC')
245
246    _has_timezone = False
247else:
248    _has_timezone = True
249
250# time zones used in Postgres timestamptz output
251_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
252    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
253    UCT='+0000', UTC='+0000', WET='+0000')
254
255
256def _timezone_as_offset(tz):
257    if tz.startswith(('+', '-')):
258        if len(tz) < 5:
259            return tz + '00'
260        return tz.replace(':', '')
261    return _timezones.get(tz, '+0000')
262
263
264def _get_timezone(tz):
265    tz = _timezone_as_offset(tz)
266    minutes = 60 * int(tz[1:3]) + int(tz[3:5])
267    if tz[0] == '-':
268        minutes = -minutes
269    return timezone(timedelta(minutes=minutes), tz)
270
271
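# Illustrative examples (not part of the original module) for the two helpers
# above; the offset string is an arbitrary assumed value:
#
#     _timezone_as_offset('CET')     # -> '+0100'
#     _timezone_as_offset('+05:30')  # -> '+0530'
#     _get_timezone('+05:30')        # -> tzinfo for UTC+05:30 named '+0530'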
272def decimal_type(decimal_type=None):
273    """Get or set global type to be used for decimal values.
274
275    Note that connections cache cast functions. To be sure a global change
276    is picked up by a running connection, call con.type_cache.reset_typecast().
277    """
278    global Decimal
279    if decimal_type is not None:
280        Decimal = decimal_type
281        set_typecast('numeric', decimal_type)
282    return Decimal
283
284
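# Illustrative usage (not part of the original module): switch the global
# numeric cast from Decimal to float; the database name below is assumed.
#
#     import pgdb
#     pgdb.decimal_type(float)            # numeric columns now cast to float
#     con = pgdb.connect(database='test')
#     con.type_cache.reset_typecast()     # let an open connection pick it up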
285def cast_bool(value):
286    """Cast boolean value in database format to bool."""
287    if value:
288        return value[0] in ('t', 'T')
289
290
291def cast_money(value):
292    """Cast money value in database format to Decimal."""
293    if value:
294        value = value.replace('(', '-')
295        return Decimal(''.join(c for c in value if c.isdigit() or c in '.-'))
296
297
298def cast_int2vector(value):
299    """Cast an int2vector value."""
300    return [int(v) for v in value.split()]
301
302
303def cast_date(value, connection):
304    """Cast a date value."""
305    # The output format depends on the server setting DateStyle.  The default
306    # setting ISO and the setting for German are actually unambiguous.  The
307    # order of days and months in the other two settings is however ambiguous,
308    # so at least here we need to consult the setting to properly parse values.
309    if value == '-infinity':
310        return date.min
311    if value == 'infinity':
312        return date.max
313    value = value.split()
314    if value[-1] == 'BC':
315        return date.min
316    value = value[0]
317    if len(value) > 10:
318        return date.max
319    fmt = connection.date_format()
320    return datetime.strptime(value, fmt).date()
321
322
323def cast_time(value):
324    """Cast a time value."""
325    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
326    return datetime.strptime(value, fmt).time()
327
328
329_re_timezone = regex('(.*)([+-].*)')
330
331
332def cast_timetz(value):
333    """Cast a timetz value."""
334    tz = _re_timezone.match(value)
335    if tz:
336        value, tz = tz.groups()
337    else:
338        tz = '+0000'
339    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
340    if _has_timezone:
341        value += _timezone_as_offset(tz)
342        fmt += '%z'
343        return datetime.strptime(value, fmt).timetz()
344    return datetime.strptime(value, fmt).timetz().replace(
345        tzinfo=_get_timezone(tz))
346
347
348def cast_timestamp(value, connection):
349    """Cast a timestamp value."""
350    if value == '-infinity':
351        return datetime.min
352    if value == 'infinity':
353        return datetime.max
354    value = value.split()
355    if value[-1] == 'BC':
356        return datetime.min
357    fmt = connection.date_format()
358    if fmt.endswith('-%Y') and len(value) > 2:
359        value = value[1:5]
360        if len(value[3]) > 4:
361            return datetime.max
362        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
363            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
364    else:
365        if len(value[0]) > 10:
366            return datetime.max
367        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
368    return datetime.strptime(' '.join(value), ' '.join(fmt))
369
370
371def cast_timestamptz(value, connection):
372    """Cast a timestamptz value."""
373    if value == '-infinity':
374        return datetime.min
375    if value == 'infinity':
376        return datetime.max
377    value = value.split()
378    if value[-1] == 'BC':
379        return datetime.min
380    fmt = connection.date_format()
381    if fmt.endswith('-%Y') and len(value) > 2:
382        value = value[1:]
383        if len(value[3]) > 4:
384            return datetime.max
385        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
386            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
387        value, tz = value[:-1], value[-1]
388    else:
389        if fmt.startswith('%Y-'):
390            tz = _re_timezone.match(value[1])
391            if tz:
392                value[1], tz = tz.groups()
393            else:
394                tz = '+0000'
395        else:
396            value, tz = value[:-1], value[-1]
397        if len(value[0]) > 10:
398            return datetime.max
399        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
400    if _has_timezone:
401        value.append(_timezone_as_offset(tz))
402        fmt.append('%z')
403        return datetime.strptime(' '.join(value), ' '.join(fmt))
404    return datetime.strptime(' '.join(value), ' '.join(fmt)).replace(
405        tzinfo=_get_timezone(tz))
406
407
408_re_interval_sql_standard = regex(
409    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
410    '(?:([+-]?[0-9]+)(?!:) ?)?'
411    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
412
413_re_interval_postgres = regex(
414    '(?:([+-]?[0-9]+) ?years? ?)?'
415    '(?:([+-]?[0-9]+) ?mons? ?)?'
416    '(?:([+-]?[0-9]+) ?days? ?)?'
417    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
418
419_re_interval_postgres_verbose = regex(
420    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
421    '(?:([+-]?[0-9]+) ?mons? ?)?'
422    '(?:([+-]?[0-9]+) ?days? ?)?'
423    '(?:([+-]?[0-9]+) ?hours? ?)?'
424    '(?:([+-]?[0-9]+) ?mins? ?)?'
425    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')
426
427_re_interval_iso_8601 = regex(
428    'P(?:([+-]?[0-9]+)Y)?'
429    '(?:([+-]?[0-9]+)M)?'
430    '(?:([+-]?[0-9]+)D)?'
431    '(?:T(?:([+-]?[0-9]+)H)?'
432    '(?:([+-]?[0-9]+)M)?'
433    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')
434
435
436def cast_interval(value):
437    """Cast an interval value."""
438    # The output format depends on the server setting IntervalStyle, but it's
439    # not necessary to consult this setting to parse it.  It's faster to just
440    # check all possible formats, and there is no ambiguity here.
441    m = _re_interval_iso_8601.match(value)
442    if m:
443        m = [d or '0' for d in m.groups()]
444        secs_ago = m.pop(5) == '-'
445        m = [int(d) for d in m]
446        years, mons, days, hours, mins, secs, usecs = m
447        if secs_ago:
448            secs = -secs
449            usecs = -usecs
450    else:
451        m = _re_interval_postgres_verbose.match(value)
452        if m:
453            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
454            secs_ago = m.pop(5) == '-'
455            m = [-int(d) for d in m] if ago else [int(d) for d in m]
456            years, mons, days, hours, mins, secs, usecs = m
457            if secs_ago:
458                secs = -secs
459                usecs = -usecs
460        else:
461            m = _re_interval_postgres.match(value)
462            if m and any(m.groups()):
463                m = [d or '0' for d in m.groups()]
464                hours_ago = m.pop(3) == '-'
465                m = [int(d) for d in m]
466                years, mons, days, hours, mins, secs, usecs = m
467                if hours_ago:
468                    hours = -hours
469                    mins = -mins
470                    secs = -secs
471                    usecs = -usecs
472            else:
473                m = _re_interval_sql_standard.match(value)
474                if m and any(m.groups()):
475                    m = [d or '0' for d in m.groups()]
476                    years_ago = m.pop(0) == '-'
477                    hours_ago = m.pop(3) == '-'
478                    m = [int(d) for d in m]
479                    years, mons, days, hours, mins, secs, usecs = m
480                    if years_ago:
481                        years = -years
482                        mons = -mons
483                    if hours_ago:
484                        hours = -hours
485                        mins = -mins
486                        secs = -secs
487                        usecs = -usecs
488                else:
489                    raise ValueError('Cannot parse interval: %s' % value)
490    days += 365 * years + 30 * mons
491    return timedelta(days=days, hours=hours, minutes=mins,
492        seconds=secs, microseconds=usecs)
493
494
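# Illustrative examples (not part of the original module): as noted above,
# years and months are approximated as 365 and 30 days in the resulting
# timedelta.
#
#     cast_interval('1 year 2 mons 3 days 04:05:06')
#     # -> timedelta(days=428, hours=4, minutes=5, seconds=6)
#     cast_interval('P1Y2M3DT4H5M6S')  # ISO 8601 style, same result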
495class Typecasts(dict):
496    """Dictionary mapping database types to typecast functions.
497
498    The cast functions get passed the string representation of a value in
499    the database which they need to convert to a Python object.  The
500    passed string will never be None since NULL values are already
501    handled before the cast function is called.
502    """
503
504    # the default cast functions
505    # (str functions are ignored but have been added for faster access)
506    defaults = {'char': str, 'bpchar': str, 'name': str,
507        'text': str, 'varchar': str,
508        'bool': cast_bool, 'bytea': unescape_bytea,
509        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
510        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
511        'float4': float, 'float8': float,
512        'numeric': Decimal, 'money': cast_money,
513        'date': cast_date, 'interval': cast_interval,
514        'time': cast_time, 'timetz': cast_timetz,
515        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
516        'int2vector': cast_int2vector, 'uuid': Uuid,
517        'anyarray': cast_array, 'record': cast_record}
518
519    connection = None  # will be set in local connection specific instances
520
521    def __missing__(self, typ):
522        """Create a cast function if it is not cached.
523
524        Note that this class never raises a KeyError,
525        but returns None when no special cast function exists.
526        """
527        if not isinstance(typ, str):
528            raise TypeError('Invalid type: %s' % typ)
529        cast = self.defaults.get(typ)
530        if cast:
531            # store default for faster access
532            cast = self._add_connection(cast)
533            self[typ] = cast
534        elif typ.startswith('_'):
535            # create array cast
536            base_cast = self[typ[1:]]
537            cast = self.create_array_cast(base_cast)
538            if base_cast:
539                # store only if base type exists
540                self[typ] = cast
541        return cast
542
543    @staticmethod
544    def _needs_connection(func):
545        """Check if a typecast function needs a connection argument."""
546        try:
547            args = get_args(func)
548        except (TypeError, ValueError):
549            return False
550        else:
551            return 'connection' in args[1:]
552
553    def _add_connection(self, cast):
554        """Add a connection argument to the typecast function if necessary."""
555        if not self.connection or not self._needs_connection(cast):
556            return cast
557        return partial(cast, connection=self.connection)
558
559    def get(self, typ, default=None):
560        """Get the typecast function for the given database type."""
561        return self[typ] or default
562
563    def set(self, typ, cast):
564        """Set a typecast function for the specified database type(s)."""
565        if isinstance(typ, basestring):
566            typ = [typ]
567        if cast is None:
568            for t in typ:
569                self.pop(t, None)
570                self.pop('_%s' % t, None)
571        else:
572            if not callable(cast):
573                raise TypeError("Cast parameter must be callable")
574            for t in typ:
575                self[t] = self._add_connection(cast)
576                self.pop('_%s' % t, None)
577
578    def reset(self, typ=None):
579        """Reset the typecasts for the specified type(s) to their defaults.
580
581        When no type is specified, all typecasts will be reset.
582        """
583        defaults = self.defaults
584        if typ is None:
585            self.clear()
586            self.update(defaults)
587        else:
588            if isinstance(typ, basestring):
589                typ = [typ]
590            for t in typ:
591                cast = defaults.get(t)
592                if cast:
593                    self[t] = self._add_connection(cast)
594                    t = '_%s' % t
595                    cast = defaults.get(t)
596                    if cast:
597                        self[t] = self._add_connection(cast)
598                    else:
599                        self.pop(t, None)
600                else:
601                    self.pop(t, None)
602                    self.pop('_%s' % t, None)
603
604    def create_array_cast(self, basecast):
605        """Create an array typecast for the given base cast."""
606        cast_array = self['anyarray']
607        def cast(v):
608            return cast_array(v, basecast)
609        return cast
610
611    def create_record_cast(self, name, fields, casts):
612        """Create a named record typecast for the given fields and casts."""
613        cast_record = self['record']
614        record = namedtuple(name, fields)
615        def cast(v):
616            return record(*cast_record(v, casts))
617        return cast
618
619
620_typecasts = Typecasts()  # this is the global typecast dictionary
621
622
623def get_typecast(typ):
624    """Get the global typecast function for the given database type(s)."""
625    return _typecasts.get(typ)
626
627
628def set_typecast(typ, cast):
629    """Set a global typecast function for the given database type(s).
630
631    Note that connections cache cast functions. To be sure a global change
632    is picked up by a running connection, call con.type_cache.reset_typecast().
633    """
634    _typecasts.set(typ, cast)
635
636
637def reset_typecast(typ=None):
638    """Reset the global typecasts for the given type(s) to their default.
639
640    When no type is specified, all typecasts will be reset.
641
642    Note that connections cache cast functions. To be sure a global change
643    is picked up by a running connection, call con.type_cache.reset_typecast().
644    """
645    _typecasts.reset(typ)
646
647
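# Illustrative usage (not part of the original module): register a global
# cast for the PostgreSQL inet type with the standard ipaddress module
# (the chosen type and cast function are only an assumed example):
#
#     import ipaddress
#     set_typecast('inet', ipaddress.ip_interface)
#     get_typecast('inet')    # -> ipaddress.ip_interface
#     reset_typecast('inet')  # back to the default (values stay strings)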
648class LocalTypecasts(Typecasts):
649    """Map typecasts, including local composite types, to cast functions."""
650
651    defaults = _typecasts
652
653    connection = None  # will be set in a connection specific instance
654
655    def __missing__(self, typ):
656        """Create a cast function if it is not cached."""
657        if typ.startswith('_'):
658            base_cast = self[typ[1:]]
659            cast = self.create_array_cast(base_cast)
660            if base_cast:
661                self[typ] = cast
662        else:
663            cast = self.defaults.get(typ)
664            if cast:
665                cast = self._add_connection(cast)
666                self[typ] = cast
667            else:
668                fields = self.get_fields(typ)
669                if fields:
670                    casts = [self[field.type] for field in fields]
671                    fields = [field.name for field in fields]
672                    cast = self.create_record_cast(typ, fields, casts)
673                    self[typ] = cast
674        return cast
675
676    def get_fields(self, typ):
677        """Return the fields for the given record type.
678
679        This method will be replaced with a method that looks up the fields
680        using the type cache of the connection.
681        """
682        return []
683
684
685class TypeCode(str):
686    """Class representing the type_code used by the DB-API 2.0.
687
688    TypeCode objects are strings equal to the PostgreSQL type name,
689    but carry some additional information.
690    """
691
692    @classmethod
693    def create(cls, oid, name, len, type, category, delim, relid):
694        """Create a type code for a PostgreSQL data type."""
695        self = cls(name)
696        self.oid = oid
697        self.len = len
698        self.type = type
699        self.category = category
700        self.delim = delim
701        self.relid = relid
702        return self
703
704FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
705
706
707class TypeCache(dict):
708    """Cache for database types.
709
710    This cache maps type OIDs and names to TypeCode strings containing
711    important information on the associated database type.
712    """
713
714    def __init__(self, cnx):
715        """Initialize type cache for connection."""
716        super(TypeCache, self).__init__()
717        self._escape_string = cnx.escape_string
718        self._src = cnx.source()
719        self._typecasts = LocalTypecasts()
720        self._typecasts.get_fields = self.get_fields
721        self._typecasts.connection = cnx
722        if cnx.server_version < 80400:
723            # older remote databases (not officially supported)
724            self._query_pg_type = ("SELECT oid, typname,"
725                " typlen, typtype, null as typcategory, typdelim, typrelid"
726                " FROM pg_type WHERE oid=%s")
727        else:
728            self._query_pg_type = ("SELECT oid, typname,"
729                " typlen, typtype, typcategory, typdelim, typrelid"
730                " FROM pg_type WHERE oid=%s")
731
732    def __missing__(self, key):
733        """Get the type info from the database if it is not cached."""
734        if isinstance(key, int):
735            oid = key
736        else:
737            if '.' not in key and '"' not in key:
738                key = '"%s"' % (key,)
739            oid = "'%s'::regtype" % (self._escape_string(key),)
740        try:
741            self._src.execute(self._query_pg_type % (oid,))
742        except ProgrammingError:
743            res = None
744        else:
745            res = self._src.fetch(1)
746        if not res:
747            raise KeyError('Type %s could not be found' % (key,))
748        res = res[0]
749        type_code = TypeCode.create(int(res[0]), res[1],
750            int(res[2]), res[3], res[4], res[5], int(res[6]))
751        self[type_code.oid] = self[str(type_code)] = type_code
752        return type_code
753
754    def get(self, key, default=None):
755        """Get the type even if it is not cached."""
756        try:
757            return self[key]
758        except KeyError:
759            return default
760
761    def get_fields(self, typ):
762        """Get the names and types of the fields of composite types."""
763        if not isinstance(typ, TypeCode):
764            typ = self.get(typ)
765            if not typ:
766                return None
767        if not typ.relid:
768            return None  # this type is not composite
769        self._src.execute("SELECT attname, atttypid"
770            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
771            " AND NOT attisdropped ORDER BY attnum" % (typ.relid,))
772        return [FieldInfo(name, self.get(int(oid)))
773            for name, oid in self._src.fetch(-1)]
774
775    def get_typecast(self, typ):
776        """Get the typecast function for the given database type."""
777        return self._typecasts.get(typ)
778
779    def set_typecast(self, typ, cast):
780        """Set a typecast function for the specified database type(s)."""
781        self._typecasts.set(typ, cast)
782
783    def reset_typecast(self, typ=None):
784        """Reset the typecast function for the specified database type(s)."""
785        self._typecasts.reset(typ)
786
787    def typecast(self, value, typ):
788        """Cast the given value according to the given database type."""
789        if value is None:
790            # for NULL values, no typecast is necessary
791            return None
792        cast = self.get_typecast(typ)
793        if not cast or cast is str:
794            # no typecast is necessary
795            return value
796        return cast(value)
797
798
799class _quotedict(dict):
800    """Dictionary with auto quoting of its items.
801
802    The quote attribute must be set to the desired quote function.
803    """
804
805    def __getitem__(self, key):
806        return self.quote(super(_quotedict, self).__getitem__(key))
807
808
809### Error Messages
810
811def _db_error(msg, cls=DatabaseError):
812    """Return DatabaseError with empty sqlstate attribute."""
813    error = cls(msg)
814    error.sqlstate = None
815    return error
816
817
818def _op_error(msg):
819    """Return OperationalError."""
820    return _db_error(msg, OperationalError)
821
822
823### Row Tuples
824
825_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$')
826
827# The result rows for database operations are returned as named tuples
828# by default. Since creating namedtuple classes is a somewhat expensive
829# operation, we cache up to 1024 of these classes by default.
830
831@lru_cache(maxsize=1024)
832def _row_factory(names):
833    """Get a namedtuple factory for row results with the given names."""
834    try:
835        try:
836            return namedtuple('Row', names, rename=True)._make
837        except TypeError:  # Python 2.6 and 3.0 do not support rename
838            names = [v if _re_fieldname.match(v) and not iskeyword(v)
839                        else 'column_%d' % (n,)
840                     for n, v in enumerate(names)]
841            return namedtuple('Row', names)._make
842    except ValueError:  # there is still a problem with the field names
843        names = ['column_%d' % (n,) for n in range(len(names))]
844        return namedtuple('Row', names)._make
845
846
847def set_row_factory_size(maxsize):
848    """Change the size of the namedtuple factory cache.
849
850    If maxsize is set to None, the cache can grow without bound.
851    """
852    global _row_factory
853    _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__)
854
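# Illustrative example (not part of the original module): the factory returned
# by _row_factory() builds named tuples from result rows; the field names and
# values below are assumed.
#
#     make_row = _row_factory(('id', 'name'))
#     make_row([1, 'Alice'])      # -> Row(id=1, name='Alice')
#     set_row_factory_size(None)  # let the factory cache grow without bound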
855
856### Cursor Object
857
858class Cursor(object):
859    """Cursor object."""
860
861    def __init__(self, dbcnx):
862        """Create a cursor object for the database connection."""
863        self.connection = self._dbcnx = dbcnx
864        self._cnx = dbcnx._cnx
865        self.type_cache = dbcnx.type_cache
866        self._src = self._cnx.source()
867        # the official attribute for describing the result columns
868        self._description = None
869        if self.row_factory is Cursor.row_factory:
870            # the row factory needs to be determined dynamically
871            self.row_factory = None
872        else:
873            self.build_row_factory = None
874        self.rowcount = -1
875        self.arraysize = 1
876        self.lastrowid = None
877
878    def __iter__(self):
879        """Make cursor compatible to the iteration protocol."""
880        return self
881
882    def __enter__(self):
883        """Enter the runtime context for the cursor object."""
884        return self
885
886    def __exit__(self, et, ev, tb):
887        """Exit the runtime context for the cursor object."""
888        self.close()
889
890    def _quote(self, value):
891        """Quote value depending on its type."""
892        if value is None:
893            return 'NULL'
894        if isinstance(value, (Hstore, Json)):
895            value = str(value)
896        if isinstance(value, basestring):
897            if isinstance(value, Binary):
898                value = self._cnx.escape_bytea(value)
899                if bytes is not str:  # Python >= 3.0
900                    value = value.decode('ascii')
901            else:
902                value = self._cnx.escape_string(value)
903            return "'%s'" % (value,)
904        if isinstance(value, float):
905            if isinf(value):
906                return "'-Infinity'" if value < 0 else "'Infinity'"
907            if isnan(value):
908                return "'NaN'"
909            return value
910        if isinstance(value, (int, long, Decimal, Literal)):
911            return value
912        if isinstance(value, datetime):
913            if value.tzinfo:
914                return "'%s'::timestamptz" % (value,)
915            return "'%s'::timestamp" % (value,)
916        if isinstance(value, date):
917            return "'%s'::date" % (value,)
918        if isinstance(value, time):
919            if value.tzinfo:
920                return "'%s'::timetz" % (value,)
921            return "'%s'::time" % value
922        if isinstance(value, timedelta):
923            return "'%s'::interval" % (value,)
924        if isinstance(value, Uuid):
925            return "'%s'::uuid" % (value,)
926        if isinstance(value, list):
927            # Quote value as an ARRAY constructor. This is better than using
928            # an array literal because it carries the information that this is
929            # an array and not a string.  One issue with this syntax is that
930            # you need to add an explicit typecast when passing empty arrays.
931            # The ARRAY keyword is actually only necessary at the top level.
932            if not value:  # exception for empty array
933                return "'{}'"
934            q = self._quote
935            try:
936                return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),)
937            except UnicodeEncodeError:  # Python 2 with non-ascii values
938                return u'ARRAY[%s]' % (','.join(unicode(q(v)) for v in value),)
939        if isinstance(value, tuple):
940            # Quote as a ROW constructor.  This is better than using a record
941            # literal because it carries the information that this is a record
942            # and not a string.  We don't use the keyword ROW in order to make
943            # this usable with the IN syntax as well.  It is only necessary
944            # when the record has a single column, which is not really useful.
945            q = self._quote
946            try:
947                return '(%s)' % (','.join(str(q(v)) for v in value),)
948            except UnicodeEncodeError:  # Python 2 with non-ascii values
949                return u'(%s)' % (','.join(unicode(q(v)) for v in value),)
950        try:
951            value = value.__pg_repr__()
952        except AttributeError:
953            raise InterfaceError(
954                'Do not know how to adapt type %s' % (type(value),))
955        if isinstance(value, (tuple, list)):
956            value = self._quote(value)
957        return value
958
959    def _quoteparams(self, string, parameters):
960        """Quote parameters.
961
962        This function works for both mappings and sequences.
963
964        The function should be used even when there are no parameters,
965        so that we have a consistent behavior regarding percent signs.
966        """
967        if not parameters:
968            try:
969                return string % ()  # unescape literal quotes if possible
970            except (TypeError, ValueError):
971                return string  # silently accept unescaped quotes
972        if isinstance(parameters, dict):
973            parameters = _quotedict(parameters)
974            parameters.quote = self._quote
975        else:
976            parameters = tuple(map(self._quote, parameters))
977        return string % parameters
978
979    def _make_description(self, info):
980        """Make the description tuple for the given field info."""
981        name, typ, size, mod = info[1:]
982        type_code = self.type_cache[typ]
983        if mod > 0:
984            mod -= 4
985        if type_code == 'numeric':
986            precision, scale = mod >> 16, mod & 0xffff
987            size = precision
988        else:
989            if not size:
990                size = type_code.size
991            if size == -1:
992                size = mod
993            precision = scale = None
994        return CursorDescription(name, type_code,
995            None, size, precision, scale, None)
996
997    @property
998    def description(self):
999        """Read-only attribute describing the result columns."""
1000        descr = self._description
1001        if self._description is True:
1002            make = self._make_description
1003            descr = [make(info) for info in self._src.listinfo()]
1004            self._description = descr
1005        return descr
1006
1007    @property
1008    def colnames(self):
1009        """Unofficial convenience method for getting the column names."""
1010        return [d[0] for d in self.description]
1011
1012    @property
1013    def coltypes(self):
1014        """Unofficial convenience method for getting the column types."""
1015        return [d[1] for d in self.description]
1016
1017    def close(self):
1018        """Close the cursor object."""
1019        self._src.close()
1020
1021    def execute(self, operation, parameters=None):
1022        """Prepare and execute a database operation (query or command)."""
1023        # The parameters may also be specified as a list of tuples, e.g. to
1024        # insert multiple rows in a single operation, but this kind of
1025        # usage is deprecated.  We make several plausibility checks because
1026        # tuples can also be passed with the meaning of ROW constructors.
1027        if (parameters and isinstance(parameters, list)
1028                and len(parameters) > 1
1029                and all(isinstance(p, tuple) for p in parameters)
1030                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
1031            return self.executemany(operation, parameters)
1032        else:
1033            # not a list of tuples
1034            return self.executemany(operation, [parameters])
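
    # Illustrative usage (not part of the original module): execute() and
    # executemany() quote pyformat parameters for you; "cur" and the table
    # and column names below are assumed.
    #
    #     cur.execute("insert into sales (qty, note) values (%(qty)s, %(note)s)",
    #                 {'qty': 3, 'note': "O'Hara"})
    #     cur.execute("select * from sales where qty > %s", (1,))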
1035
1036    def executemany(self, operation, seq_of_parameters):
1037        """Prepare operation and execute it against a parameter sequence."""
1038        if not seq_of_parameters:
1039            # don't do anything without parameters
1040            return
1041        self._description = None
1042        self.rowcount = -1
1043        # first try to execute all queries
1044        rowcount = 0
1045        sql = "BEGIN"
1046        try:
1047            if not self._dbcnx._tnx:
1048                try:
1049                    self._src.execute(sql)
1050                except DatabaseError:
1051                    raise  # database provides error message
1052                except Exception:
1053                    raise _op_error("Can't start transaction")
1054                self._dbcnx._tnx = True
1055            for parameters in seq_of_parameters:
1056                sql = operation
1057                sql = self._quoteparams(sql, parameters)
1058                rows = self._src.execute(sql)
1059                if rows:  # true if not DML
1060                    rowcount += rows
1061                else:
1062                    self.rowcount = -1
1063        except DatabaseError:
1064            raise  # database provides error message
1065        except Error as err:
1066            raise _db_error(
1067                "Error in '%s': '%s' " % (sql, err), InterfaceError)
1068        except Exception as err:
1069            raise _op_error("Internal error in '%s': %s" % (sql, err))
1070        # then initialize result raw count and description
1071        if self._src.resulttype == RESULT_DQL:
1072            self._description = True  # fetch on demand
1073            self.rowcount = self._src.ntuples
1074            self.lastrowid = None
1075            if self.build_row_factory:
1076                self.row_factory = self.build_row_factory()
1077        else:
1078            self.rowcount = rowcount
1079            self.lastrowid = self._src.oidstatus()
1080        # return the cursor object, so you can write statements such as
1081        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
1082        return self
1083
1084    def fetchone(self):
1085        """Fetch the next row of a query result set."""
1086        res = self.fetchmany(1, False)
1087        try:
1088            return res[0]
1089        except IndexError:
1090            return None
1091
1092    def fetchall(self):
1093        """Fetch all (remaining) rows of a query result."""
1094        return self.fetchmany(-1, False)
1095
1096    def fetchmany(self, size=None, keep=False):
1097        """Fetch the next set of rows of a query result.
1098
1099        The number of rows to fetch per call is specified by the
1100        size parameter. If it is not given, the cursor's arraysize
1101        determines the number of rows to be fetched. If you set
1102        the keep parameter to true, this is kept as new arraysize.
1103        """
1104        if size is None:
1105            size = self.arraysize
1106        if keep:
1107            self.arraysize = size
1108        try:
1109            result = self._src.fetch(size)
1110        except DatabaseError:
1111            raise
1112        except Error as err:
1113            raise _db_error(str(err))
1114        typecast = self.type_cache.typecast
1115        return [self.row_factory([typecast(value, typ)
1116            for typ, value in zip(self.coltypes, row)]) for row in result]
1117
1118    def callproc(self, procname, parameters=None):
1119        """Call a stored database procedure with the given name.
1120
1121        The sequence of parameters must contain one entry for each input
1122        argument that the procedure expects. The result of the call is the
1123        same as this input sequence; replacement of output and input/output
1124        parameters in the return value is currently not supported.
1125
1126        The procedure may also provide a result set as output. These can be
1127        requested through the standard fetch methods of the cursor.
1128        """
1129        n = parameters and len(parameters) or 0
1130        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
1131        self.execute(query, parameters)
1132        return parameters
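
    # Illustrative usage (not part of the original module): "cur" is an
    # assumed cursor; a set-returning function can be called and its rows
    # fetched afterwards.
    #
    #     cur.callproc('generate_series', (1, 3))
    #     cur.fetchall()  # three rows, each with a generate_series column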
1133
1134    def copy_from(self, stream, table,
1135            format=None, sep=None, null=None, size=None, columns=None):
1136        """Copy data from an input stream to the specified table.
1137
1138        The input stream can be a file-like object with a read() method or
1139        it can also be an iterable returning a row or multiple rows of input
1140        on each iteration.
1141
1142        The format must be text, csv or binary. The sep option sets the
1143        column separator (delimiter) used in the non binary formats.
1144        The null option sets the textual representation of NULL in the input.
1145
1146        The size option sets the size of the buffer used when reading data
1147        from file-like objects.
1148
1149        The copy operation can be restricted to a subset of columns. If no
1150        columns are specified, all of them will be copied.
1151        """
1152        binary_format = format == 'binary'
1153        try:
1154            read = stream.read
1155        except AttributeError:
1156            if size:
1157                raise ValueError("Size must only be set for file-like objects")
1158            if binary_format:
1159                input_type = bytes
1160                type_name = 'byte strings'
1161            else:
1162                input_type = basestring
1163                type_name = 'strings'
1164
1165            if isinstance(stream, basestring):
1166                if not isinstance(stream, input_type):
1167                    raise ValueError("The input must be %s" % (type_name,))
1168                if not binary_format:
1169                    if isinstance(stream, str):
1170                        if not stream.endswith('\n'):
1171                            stream += '\n'
1172                    else:
1173                        if not stream.endswith(b'\n'):
1174                            stream += b'\n'
1175
1176                def chunks():
1177                    yield stream
1178
1179            elif isinstance(stream, Iterable):
1180
1181                def chunks():
1182                    for chunk in stream:
1183                        if not isinstance(chunk, input_type):
1184                            raise ValueError(
1185                                "Input stream must consist of %s"
1186                                % (type_name,))
1187                        if isinstance(chunk, str):
1188                            if not chunk.endswith('\n'):
1189                                chunk += '\n'
1190                        else:
1191                            if not chunk.endswith(b'\n'):
1192                                chunk += b'\n'
1193                        yield chunk
1194
1195            else:
1196                raise TypeError("Need an input stream to copy from")
1197        else:
1198            if size is None:
1199                size = 8192
1200            elif not isinstance(size, int):
1201                raise TypeError("The size option must be an integer")
1202            if size > 0:
1203
1204                def chunks():
1205                    while True:
1206                        buffer = read(size)
1207                        yield buffer
1208                        if not buffer or len(buffer) < size:
1209                            break
1210
1211            else:
1212
1213                def chunks():
1214                    yield read()
1215
1216        if not table or not isinstance(table, basestring):
1217            raise TypeError("Need a table to copy to")
1218        if table.lower().startswith('select'):
1219            raise ValueError("Must specify a table, not a query")
1220        else:
1221            table = '"%s"' % (table,)
1222        operation = ['copy %s' % (table,)]
1223        options = []
1224        params = []
1225        if format is not None:
1226            if not isinstance(format, basestring):
1227                raise TypeError("The format option must be a string")
1228            if format not in ('text', 'csv', 'binary'):
1229                raise ValueError("Invalid format")
1230            options.append('format %s' % (format,))
1231        if sep is not None:
1232            if not isinstance(sep, basestring):
1233                raise TypeError("The sep option must be a string")
1234            if format == 'binary':
1235                raise ValueError(
1236                    "The sep option is not allowed with binary format")
1237            if len(sep) != 1:
1238                raise ValueError(
1239                    "The sep option must be a single one-byte character")
1240            options.append('delimiter %s')
1241            params.append(sep)
1242        if null is not None:
1243            if not isinstance(null, basestring):
1244                raise TypeError("The null option must be a string")
1245            options.append('null %s')
1246            params.append(null)
1247        if columns:
1248            if not isinstance(columns, basestring):
1249                columns = ','.join('"%s"' % (col,) for col in columns)
1250            operation.append('(%s)' % (columns,))
1251        operation.append("from stdin")
1252        if options:
1253            operation.append('(%s)' % (','.join(options),))
1254        operation = ' '.join(operation)
1255
1256        putdata = self._src.putdata
1257        self.execute(operation, params)
1258
1259        try:
1260            for chunk in chunks():
1261                putdata(chunk)
1262        except BaseException as error:
1263            self.rowcount = -1
1264            # the following call will re-raise the error
1265            putdata(error)
1266        else:
1267            self.rowcount = putdata(None)
1268
1269        # return the cursor object, so you can chain operations
1270        return self
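
    # Illustrative usage (not part of the original module): copy CSV rows
    # from an iterable; "cur" and the table and column names are assumed.
    #
    #     rows = ["1,Alice\n", "2,Bob\n"]
    #     cur.copy_from(rows, 'employees', format='csv', columns=['id', 'name'])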
1271
1272    def copy_to(self, stream, table,
1273            format=None, sep=None, null=None, decode=None, columns=None):
1274        """Copy data from the specified table to an output stream.
1275
1276        The output stream can be a file-like object with a write() method or
1277        it can also be None, in which case the method will return a generator
1278        yielding a row on each iteration.
1279
1280        Output will be returned as byte strings unless you set decode to true.
1281
1282        Note that you can also use a select query instead of the table name.
1283
1284        The format must be text, csv or binary. The sep option sets the
1285        column separator (delimiter) used in the non binary formats.
1286        The null option sets the textual representation of NULL in the output.
1287
1288        The copy operation can be restricted to a subset of columns. If no
1289        columns are specified, all of them will be copied.
1290        """
1291        binary_format = format == 'binary'
1292        if stream is not None:
1293            try:
1294                write = stream.write
1295            except AttributeError:
1296                raise TypeError("Need an output stream to copy to")
1297        if not table or not isinstance(table, basestring):
1298            raise TypeError("Need a table to copy from")
1299        if table.lower().startswith('select'):
1300            if columns:
1301                raise ValueError("Columns must be specified in the query")
1302            table = '(%s)' % (table,)
1303        else:
1304            table = '"%s"' % (table,)
1305        operation = ['copy %s' % (table,)]
1306        options = []
1307        params = []
1308        if format is not None:
1309            if not isinstance(format, basestring):
1310                raise TypeError("The format option must be a string")
1311            if format not in ('text', 'csv', 'binary'):
1312                raise ValueError("Invalid format")
1313            options.append('format %s' % (format,))
1314        if sep is not None:
1315            if not isinstance(sep, basestring):
1316                raise TypeError("The sep option must be a string")
1317            if binary_format:
1318                raise ValueError(
1319                    "The sep option is not allowed with binary format")
1320            if len(sep) != 1:
1321                raise ValueError(
1322                    "The sep option must be a single one-byte character")
1323            options.append('delimiter %s')
1324            params.append(sep)
1325        if null is not None:
1326            if not isinstance(null, basestring):
1327                raise TypeError("The null option must be a string")
1328            options.append('null %s')
1329            params.append(null)
1330        if decode is None:
1331            if format == 'binary':
1332                decode = False
1333            else:
1334                decode = str is unicode
1335        else:
1336            if not isinstance(decode, (int, bool)):
1337                raise TypeError("The decode option must be a boolean")
1338            if decode and binary_format:
1339                raise ValueError(
1340                    "The decode option is not allowed with binary format")
1341        if columns:
1342            if not isinstance(columns, basestring):
1343                columns = ','.join('"%s"' % (col,) for col in columns)
1344            operation.append('(%s)' % (columns,))
1345
1346        operation.append("to stdout")
1347        if options:
1348            operation.append('(%s)' % (','.join(options),))
1349        operation = ' '.join(operation)
1350
1351        getdata = self._src.getdata
1352        self.execute(operation, params)
1353
1354        def copy():
1355            self.rowcount = 0
1356            while True:
1357                row = getdata(decode)
1358                if isinstance(row, int):
1359                    if self.rowcount != row:
1360                        self.rowcount = row
1361                    break
1362                self.rowcount += 1
1363                yield row
1364
1365        if stream is None:
1366            # no input stream, return the generator
1367            return copy()
1368
1369        # write the rows to the file-like output stream
1370        for row in copy():
1371            write(row)
1372
1373        # return the cursor object, so you can chain operations
1374        return self
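
    # Illustrative usage (not part of the original module): stream a table
    # as decoded CSV lines; "cur" and the table name are assumed.
    #
    #     for line in cur.copy_to(None, 'employees', format='csv', decode=True):
    #         print(line, end='')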
1375
1376    def __next__(self):
1377        """Return the next row (support for the iteration protocol)."""
1378        res = self.fetchone()
1379        if res is None:
1380            raise StopIteration
1381        return res
1382
1383    # Note that since Python 3.0 the iterator protocol uses __next__()
1384    # instead of next(); we keep next() only for backward compatibility of pgdb.
1385    next = __next__
1386
1387    @staticmethod
1388    def nextset():
1389        """Not supported."""
1390        raise NotSupportedError("The nextset() method is not supported")
1391
1392    @staticmethod
1393    def setinputsizes(sizes):
1394        """Not supported."""
1395        pass  # unsupported, but silently passed
1396
1397    @staticmethod
1398    def setoutputsize(size, column=0):
1399        """Not supported."""
1400        pass  # unsupported, but silently passed
1401
1402    @staticmethod
1403    def row_factory(row):
1404        """Process rows before they are returned.
1405
1406        You can overwrite this statically with a custom row factory, or
1407        you can build a row factory dynamically with build_row_factory().
1408
1409        For example, you can create a Cursor class that returns rows as
1410        Python dictionaries like this:
1411
1412            class DictCursor(pgdb.Cursor):
1413
1414                def row_factory(self, row):
1415                    return {desc[0]: value
1416                        for desc, value in zip(self.description, row)}
1417
1418            cur = DictCursor(con)  # get one DictCursor instance or
1419            con.cursor_type = DictCursor  # always use DictCursor instances
1420        """
1421        raise NotImplementedError
1422
1423    def build_row_factory(self):
1424        """Build a row factory based on the current description.
1425
1426        This implementation builds a row factory for creating named tuples.
1427        You can overwrite this method if you want to dynamically create
1428        different row factories whenever the column description changes.
1429        """
1430        names = self.colnames
1431        if names:
1432            return _row_factory(tuple(names))
1433
1434
1435
1436CursorDescription = namedtuple('CursorDescription',
1437    ['name', 'type_code', 'display_size', 'internal_size',
1438     'precision', 'scale', 'null_ok'])
1439
1440
1441### Connection Objects
1442
1443class Connection(object):
1444    """Connection object."""
1445
1446    # expose the exceptions as attributes on the connection object
1447    Error = Error
1448    Warning = Warning
1449    InterfaceError = InterfaceError
1450    DatabaseError = DatabaseError
1451    InternalError = InternalError
1452    OperationalError = OperationalError
1453    ProgrammingError = ProgrammingError
1454    IntegrityError = IntegrityError
1455    DataError = DataError
1456    NotSupportedError = NotSupportedError
1457
1458    def __init__(self, cnx):
1459        """Create a database connection object."""
1460        self._cnx = cnx  # connection
1461        self._tnx = False  # transaction state
1462        self.type_cache = TypeCache(cnx)
1463        self.cursor_type = Cursor
1464        try:
1465            self._cnx.source()
1466        except Exception:
1467            raise _op_error("Invalid connection")
1468
1469    def __enter__(self):
1470        """Enter the runtime context for the connection object.
1471
1472        The runtime context can be used for running transactions.
1473        """
1474        return self
1475
1476    def __exit__(self, et, ev, tb):
1477        """Exit the runtime context for the connection object.
1478
1479        This does not close the connection, but it ends a transaction.
1480        """
1481        if et is None and ev is None and tb is None:
1482            self.commit()
1483        else:
1484            self.rollback()
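
    # Illustrative usage (not part of the original module): the context
    # manager wraps a transaction; "con" below is an assumed connection.
    #
    #     with con:
    #         con.cursor().execute("update stock set qty = qty - 1")
    #     # committed on success, rolled back if an exception was raised;
    #     # the connection itself stays open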
1485
1486    def close(self):
1487        """Close the connection object."""
1488        if self._cnx:
1489            if self._tnx:
1490                try:
1491                    self.rollback()
1492                except DatabaseError:
1493                    pass
1494            self._cnx.close()
1495            self._cnx = None
1496        else:
1497            raise _op_error("Connection has been closed")
1498
1499    @property
1500    def closed(self):
1501        """Check whether the connection has been closed or is broken."""
1502        try:
1503            return not self._cnx or self._cnx.status != 1
1504        except TypeError:
1505            return True
1506
1507    def commit(self):
1508        """Commit any pending transaction to the database."""
1509        if self._cnx:
1510            if self._tnx:
1511                self._tnx = False
1512                try:
1513                    self._cnx.source().execute("COMMIT")
1514                except DatabaseError:
1515                    raise
1516                except Exception:
1517                    raise _op_error("Can't commit")
1518        else:
1519            raise _op_error("Connection has been closed")
1520
1521    def rollback(self):
1522        """Roll back to the start of any pending transaction."""
1523        if self._cnx:
1524            if self._tnx:
1525                self._tnx = False
1526                try:
1527                    self._cnx.source().execute("ROLLBACK")
1528                except DatabaseError:
1529                    raise
1530                except Exception:
1531                    raise _op_error("Can't rollback")
1532        else:
1533            raise _op_error("Connection has been closed")
1534
1535    def cursor(self):
1536        """Return a new cursor object using the connection."""
1537        if self._cnx:
1538            try:
1539                return self.cursor_type(self)
1540            except Exception:
1541                raise _op_error("Invalid connection")
1542        else:
1543            raise _op_error("Connection has been closed")
1544
1545    if shortcutmethods:  # otherwise do not implement and document this
1546
1547        def execute(self, operation, params=None):
1548            """Shortcut method to run an operation on an implicit cursor."""
1549            cursor = self.cursor()
1550            cursor.execute(operation, params)
1551            return cursor
1552
1553        def executemany(self, operation, param_seq):
1554            """Shortcut method to run an operation against a sequence."""
1555            cursor = self.cursor()
1556            cursor.executemany(operation, param_seq)
1557            return cursor
1558
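        # A minimal sketch (not part of the module), applicable only when
        # the module-level shortcutmethods flag is set:
        #
        #     cur = con.execute("select version()")
        #     print(cur.fetchone())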
1559
1560### Module Interface
1561
1562_connect = connect
1563
1564def connect(dsn=None,
1565        user=None, password=None,
1566        host=None, database=None, **kwargs):
1567    """Connect to a database."""
1568    # first get params from DSN
1569    dbport = -1
1570    dbhost = ""
1571    dbname = ""
1572    dbuser = ""
1573    dbpasswd = ""
1574    dbopt = ""
1575    try:
1576        params = dsn.split(":")
1577        dbhost = params[0]
1578        dbname = params[1]
1579        dbuser = params[2]
1580        dbpasswd = params[3]
1581        dbopt = params[4]
1582    except (AttributeError, IndexError, TypeError):
1583        pass
1584
1585    # override if necessary
1586    if user is not None:
1587        dbuser = user
1588    if password is not None:
1589        dbpasswd = password
1590    if database is not None:
1591        dbname = database
1592    if host is not None:
1593        try:
1594            params = host.split(":")
1595            dbhost = params[0]
1596            dbport = int(params[1])
1597        except (AttributeError, IndexError, TypeError, ValueError):
1598            pass
1599
1600    # empty host is localhost
1601    if dbhost == "":
1602        dbhost = None
1603    if dbuser == "":
1604        dbuser = None
1605
1606    # pass keyword arguments as connection info string
1607    if kwargs:
1608        kwargs = list(kwargs.items())
1609        if '=' in dbname:
1610            dbname = [dbname]
1611        else:
1612            kwargs.insert(0, ('dbname', dbname))
1613            dbname = []
1614        for kw, value in kwargs:
1615            value = str(value)
1616            if not value or ' ' in value:
1617                value = "'%s'" % (value.replace(
1618                    '\\', '\\\\').replace("'", "\\'"),)
1619            dbname.append('%s=%s' % (kw, value))
1620        dbname = ' '.join(dbname)
1621
1622    # open the connection
1623    cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd)
1624    return Connection(cnx)
1625
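# A minimal sketch (not part of the module) of how additional keyword
# arguments are folded into a libpq-style connection string; the database
# name and option value are hypothetical:
#
#     con = connect(database='testdb', user='postgres', sslmode='require')
#     # passed on internally as the connection string
#     # "dbname=testdb sslmode=require" plus the separate user parameter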
1626
1627### Types Handling
1628
1629class Type(frozenset):
1630    """Type class for a group of PostgreSQL data types.
1631
1632    PostgreSQL is object-oriented: types are dynamic.
1633    We must thus use type names as internal type codes.
1634    """
1635
1636    def __new__(cls, values):
1637        if isinstance(values, basestring):
1638            values = values.split()
1639        return super(Type, cls).__new__(cls, values)
1640
1641    def __eq__(self, other):
1642        if isinstance(other, basestring):
1643            if other.startswith('_'):
1644                other = other[1:]
1645            return other in self
1646        else:
1647            return super(Type, self).__eq__(other)
1648
1649    def __ne__(self, other):
1650        if isinstance(other, basestring):
1651            if other.startswith('_'):
1652                other = other[1:]
1653            return other not in self
1654        else:
1655            return super(Type, self).__ne__(other)
1656
1657
1658class ArrayType:
1659    """Type class for PostgreSQL array types."""
1660
1661    def __eq__(self, other):
1662        if isinstance(other, basestring):
1663            return other.startswith('_')
1664        else:
1665            return isinstance(other, ArrayType)
1666
1667    def __ne__(self, other):
1668        if isinstance(other, basestring):
1669            return not other.startswith('_')
1670        else:
1671            return not isinstance(other, ArrayType)
1672
1673
1674class RecordType:
1675    """Type class for PostgreSQL record types."""
1676
1677    def __eq__(self, other):
1678        if isinstance(other, TypeCode):
1679            return other.type == 'c'
1680        elif isinstance(other, basestring):
1681            return other == 'record'
1682        else:
1683            return isinstance(other, RecordType)
1684
1685    def __ne__(self, other):
1686        if isinstance(other, TypeCode):
1687            return other.type != 'c'
1688        elif isinstance(other, basestring):
1689            return other != 'record'
1690        else:
1691            return not isinstance(other, RecordType)
1692
1693
1694# Mandatory type objects defined by DB-API 2 specs:
1695
1696STRING = Type('char bpchar name text varchar')
1697BINARY = Type('bytea')
1698NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
1699DATETIME = Type('date time timetz timestamp timestamptz interval'
1700    ' abstime reltime')  # these are very old
1701ROWID = Type('oid')
1702
1703
1704# Additional type objects (more specific):
1705
1706BOOL = Type('bool')
1707SMALLINT = Type('int2')
1708INTEGER = Type('int2 int4 int8 serial')
1709LONG = Type('int8')
1710FLOAT = Type('float4 float8')
1711NUMERIC = Type('numeric')
1712MONEY = Type('money')
1713DATE = Type('date')
1714TIME = Type('time timetz')
1715TIMESTAMP = Type('timestamp timestamptz')
1716INTERVAL = Type('interval')
1717UUID = Type('uuid')
1718HSTORE = Type('hstore')
1719JSON = Type('json jsonb')
1720
1721# Type object for arrays (also equate to their base types):
1722
1723ARRAY = ArrayType()
1724
1725# Type object for records (encompassing all composite types):
1726
1727RECORD = RecordType()
1728
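# A minimal sketch (not part of the module) of how these type objects
# compare against the type names used as type codes in cursor.description:
#
#     STRING == 'varchar'   # True
#     NUMBER == 'numeric'   # True
#     ARRAY == '_int4'      # True, since array type names start with '_'
#     INTEGER == '_int4'    # True as well, the leading '_' is stripped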
1729
1730# Mandatory type helpers defined by DB-API 2 specs:
1731
1732def Date(year, month, day):
1733    """Construct an object holding a date value."""
1734    return date(year, month, day)
1735
1736
1737def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
1738    """Construct an object holding a time value."""
1739    return time(hour, minute, second, microsecond, tzinfo)
1740
1741
1742def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
1743        tzinfo=None):
1744    """Construct an object holding a time stamp value."""
1745    return datetime(year, month, day, hour, minute, second, microsecond, tzinfo)
1746
1747
1748def DateFromTicks(ticks):
1749    """Construct an object holding a date value from the given ticks value."""
1750    return Date(*localtime(ticks)[:3])
1751
1752
1753def TimeFromTicks(ticks):
1754    """Construct an object holding a time value from the given ticks value."""
1755    return Time(*localtime(ticks)[3:6])
1756
1757
1758def TimestampFromTicks(ticks):
1759    """Construct an object holding a time stamp from the given ticks value."""
1760    return Timestamp(*localtime(ticks)[:6])
1761
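# A small usage sketch (not part of the module) of the ticks-based helpers,
# which convert a POSIX timestamp via time.localtime():
#
#     DateFromTicks(0)        # the epoch converted to a local date
#     TimestampFromTicks(0)   # the epoch as a naive local datetime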
1762
1763class Binary(bytes):
1764    """Construct an object capable of holding a binary (long) string value."""
1765
1766
1767# Additional type helpers for PyGreSQL:
1768
1769def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
1770    """Construct an object holding a time interval value."""
1771    return timedelta(days, hours=hours, minutes=minutes, seconds=seconds,
1772        microseconds=microseconds)
1773
1774
1775Uuid = Uuid  # Construct an object holding a UUID value
1776
1777
1778class Hstore(dict):
1779    """Wrapper class for marking hstore values."""
1780
1781    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
1782    _re_escape = regex(r'(["\\])')
1783
1784    @classmethod
1785    def _quote(cls, s):
1786        if s is None:
1787            return 'NULL'
1788        if not s:
1789            return '""'
1790        quote = cls._re_quote.search(s)
1791        s = cls._re_escape.sub(r'\\\1', s)
1792        if quote:
1793            s = '"%s"' % (s,)
1794        return s
1795
1796    def __str__(self):
1797        q = self._quote
1798        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
1799
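# A minimal sketch (not part of the module) of the hstore literal produced
# by the wrapper; key order follows the dict iteration order:
#
#     str(Hstore({'label': 'x y', 'note': None}))
#     # -> 'label=>"x y",note=>NULL'
#     # values containing spaces, commas, '=' or '>' are quoted,
#     # and None becomes an unquoted NULL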
1800
1801class Json:
1802    """Construct a wrapper for holding an object serializable to JSON."""
1803
1804    def __init__(self, obj, encode=None):
1805        self.obj = obj
1806        self.encode = encode or jsonencode
1807
1808    def __str__(self):
1809        obj = self.obj
1810        if isinstance(obj, basestring):
1811            return obj
1812        return self.encode(obj)
1813
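# A minimal sketch (not part of the module): the wrapper serializes its
# object when converted to a string, so it can be passed as a parameter
# for json or jsonb columns:
#
#     str(Json({'answer': 42}))   # -> '{"answer": 42}'
#     str(Json('[1, 2, 3]'))      # strings are passed through unchanged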
1814
1815class Literal:
1816    """Construct a wrapper for holding a literal SQL string."""
1817
1818    def __init__(self, sql):
1819        self.sql = sql
1820
1821    def __str__(self):
1822        return self.sql
1823
1824    __pg_repr__ = __str__
1825
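# A minimal sketch (not part of the module): a Literal is meant to be
# inlined verbatim instead of being quoted as a value, so it must only be
# built from trusted input; the table name "my_table" is hypothetical:
#
#     cur.execute("select * from %(table)s", {'table': Literal('my_table')})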
1826
1827# If run as script, print some information:
1828
1829if __name__ == '__main__':
1830    print('PyGreSQL version', version)
1831    print('')
1832    print(__doc__)