source: trunk/pgdb.py @ 846

Last change on this file since 846 was 846, checked in by cito, 4 years ago

Do not reset cursor attributes when cursor is closed

SQLAlchemy for instance checks the rowcount after the cursor is closed.
You will only get an error if you try to fetch rows from a closed cursor.

  • Property svn:keywords set to Id
File size: 54.1 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 846 2016-02-09 07:20:10Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
    # All parts are optional. You may also pass host, database, user
    # and password as keyword arguments. To pass a port,
    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence) if
    # they are passed. The binding syntax is the same as the % operator
    # (pyformat paramstyle); the parameter values are quoted automatically.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
    cursor.rowcount # number of rows in the result set of a query,
    # otherwise the number of affected rows. Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70__version__ = version
71
72from datetime import date, time, datetime, timedelta
73from time import localtime
74from decimal import Decimal
75from uuid import UUID as Uuid
76from math import isnan, isinf
77from collections import namedtuple
78from functools import partial
79from re import compile as regex
80from json import loads as jsondecode, dumps as jsonencode
81
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# The abstract base classes live in collections.abc since Python 3.3;
# the aliases in collections itself were removed in Python 3.10.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
98
99
### Module Constants

# the module complies with the DB API 2.0 specification
apilevel = '2.0'

# threads may share the module, but not connections
threadsafety = 1

# this module uses extended Python format codes, e.g. %(name)s
paramstyle = 'pyformat'

# shortcut methods are not part of DB API 2 and are not recommended
# by the DB SIG, but they can be handy, so this module provides them
shortcutmethods = 1
114
115
116### Internal Type Handling
117
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of func (legacy inspect API)."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the list of parameter names of func in declaration order."""
        parameters = signature(func).parameters
        return [name for name in parameters]
129
try:
    # probe whether strptime supports the %z directive
    if datetime.strptime('+0100', '%z') is None:
        raise ValueError
except ValueError:  # Python < 3.2
    timezones = None
else:
    # UTC offsets of the time zone abbreviations that may appear
    # in Postgres timestamptz output
    timezones = {
        'CET': '+0100', 'EET': '+0200', 'EST': '-0500',
        'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700',
        'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000',
    }
140
141
def decimal_type(decimal_type=None):
    """Get or set global type to be used for decimal values.

    Called without argument, the current decimal type is returned.
    Otherwise the given type replaces the module-level Decimal and is
    registered as global typecast for 'numeric'; the new type is returned.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    Decimal = decimal_type
    set_typecast('numeric', decimal_type)
    return Decimal
153
154
def cast_bool(value):
    """Cast boolean value in database format to bool.

    Only the first character is examined ('t' or 'T' means True);
    an empty value gives None.
    """
    if not value:
        return None
    return value[0] in ('t', 'T')
159
160
def cast_money(value):
    """Cast money value in database format to Decimal.

    An opening parenthesis (presumably marking a negative amount) becomes
    a minus sign; all other characters except digits, '.' and '-' (such as
    currency symbols or group separators) are dropped.  The result uses
    the current global Decimal type.  An empty value gives None.
    """
    if not value:
        return None
    signed = value.replace('(', '-')
    kept = [ch for ch in signed if ch.isdigit() or ch in '.-']
    return Decimal(''.join(kept))
166
167
def cast_int2vector(value):
    """Cast an int2vector value (space separated integers) to a list of ints."""
    return list(map(int, value.split()))
171
172
def cast_date(value, connection):
    """Cast a date value.

    The connection's date_format() provides the strptime format to use.
    Infinite dates and dates BC are mapped to date.min or date.max, and
    dates with more than ten characters (years beyond four digits) to
    date.max.
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value in ('-infinity', 'infinity'):
        return date.min if value.startswith('-') else date.max
    words = value.split()
    if words[-1] == 'BC':
        return date.min
    day = words[0]
    if len(day) > 10:
        return date.max
    return datetime.strptime(day, connection.date_format()).date()
191
192
def cast_time(value):
    """Cast a time value given as HH:MM:SS with optional fractional seconds."""
    if len(value) > 8:
        # longer than HH:MM:SS, so there are fractional seconds
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    return datetime.strptime(value, fmt).time()
197
198
# splits off the part starting at the last '+' or '-' (greedy first group)
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value.

    The trailing UTC offset is split off; without one, '+0000' is assumed.
    If strptime supports %z (timezones is set), the offset is normalized
    to the +HHMM form and the result is time zone aware; otherwise the
    offset is ignored and a naive time is returned.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
    if timezones:
        if tz.startswith(('+', '-')):
            # pad offsets that only give hours, drop the colon otherwise
            tz = tz + '00' if len(tz) < 5 else tz.replace(':', '')
        else:
            # known abbreviation, or fall back to UTC
            tz = timezones.get(tz, '+0000')
        value += tz
        fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
223
224
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    The connection's date_format() tells which DateStyle output to expect.
    Infinite timestamps and timestamps BC are mapped to datetime.min or
    datetime.max, and years with more than four digits to datetime.max.
    """
    if value in ('-infinity', 'infinity'):
        return datetime.min if value.startswith('-') else datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # written out form: the leading word is skipped, then come
        # month and day (order depending on the format), time and year
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        clock_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        fmt_parts = [day_fmt, clock_fmt, '%Y']
    else:
        # numeric date followed by the time
        if len(parts[0]) > 10:
            return datetime.max
        clock_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        fmt_parts = [date_fmt, clock_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(fmt_parts))
246
247
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    The connection's date_format() tells which DateStyle output to expect.
    Infinite timestamps and timestamps BC are mapped to datetime.min or
    datetime.max, and years with more than four digits to datetime.max.
    If strptime supports %z (timezones is set), the zone is converted to
    a +HHMM offset (unknown abbreviations become '+0000') and the result
    is time zone aware; otherwise the zone is ignored.
    """
    if value in ('-infinity', 'infinity'):
        return datetime.min if value.startswith('-') else datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # written out form: the leading word is skipped, then come
        # month and day (order depending on the format), time, year, zone
        parts = parts[1:]
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        clock_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        fmt_parts = [day_fmt, clock_fmt, '%Y']
        parts, tz = parts[:-1], parts[-1]
    else:
        if date_fmt.startswith('%Y-'):
            # the zone offset is attached directly to the time
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        else:
            # the zone is a separate last word
            parts, tz = parts[:-1], parts[-1]
        if len(parts[0]) > 10:
            return datetime.max
        clock_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        fmt_parts = [date_fmt, clock_fmt]
    if timezones:
        if tz.startswith(('+', '-')):
            # pad offsets that only give hours, drop the colon otherwise
            tz = tz + '00' if len(tz) < 5 else tz.replace(':', '')
        else:
            # known abbreviation, or fall back to UTC
            tz = timezones.get(tz, '+0000')
        parts.append(tz)
        fmt_parts.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(fmt_parts))
290
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _fraction_to_usecs(fraction):
    """Convert the digits after the decimal point of seconds to microseconds.

    The digits are a decimal fraction, not a count of microseconds:
    '5' means 0.5 seconds (500000 microseconds).  They are therefore
    right-padded to six places; further digits are truncated.
    """
    return int(fraction[:6].ljust(6, '0'))


def cast_interval(value):
    """Cast an interval value.

    Years and months are converted to days with fixed lengths of 365
    and 30 days, since timedelta has no calendar units.
    Raises ValueError if the value cannot be parsed.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        secs_ago = m.pop(5) == '-'
        usecs = _fraction_to_usecs(m.pop())
        years, mons, days, hours, mins, secs = [int(d) for d in m]
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            secs_ago = m.pop(5) == '-'
            usecs = _fraction_to_usecs(m.pop())
            years, mons, days, hours, mins, secs = [int(d) for d in m]
            if ago:
                # a trailing 'ago' negates all parts of the interval
                years, mons, days = -years, -mons, -days
                hours, mins, secs, usecs = -hours, -mins, -secs, -usecs
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                hours_ago = m.pop(3) == '-'
                usecs = _fraction_to_usecs(m.pop())
                years, mons, days, hours, mins, secs = [int(d) for d in m]
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = [d or '0' for d in m.groups()]
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    usecs = _fraction_to_usecs(m.pop())
                    years, mons, days, hours, mins, secs = [
                        int(d) for d in m]
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
376
377
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.

    Entries are created lazily by __missing__, so an empty instance
    returns the same cast functions as one filled with the defaults.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
        'float4': float, 'float8': float,
        'numeric': Decimal, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': Uuid,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in local connection specific instances

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError for a type name,
        but returns None when no special cast function exists.
        Raises TypeError if typ is not a string.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # create array cast
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                # store only if base type exists
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument.

        This is the case if func has a parameter named 'connection' other
        than its first one.  Functions whose signature cannot be inspected
        are assumed not to need a connection.
        """
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type."""
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        Passing None as cast removes the typecast, so that the default
        will be used again.  The array typecast for each type is always
        dropped, so that it is rebuilt from the new base typecast.
        Raises TypeError if cast is neither None nor callable.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.

        Cached entries are dropped rather than copied from the defaults,
        so that __missing__ rebuilds them on demand.  Copying would lose
        the connection argument that default casts such as cast_date need
        in connection specific instances, and copied array casts would
        still wrap the base casts of the defaults.
        """
        if typ is None:
            # back to the initial state; entries are recreated lazily
            self.clear()
            return
        if isinstance(typ, basestring):
            typ = [typ]
        defaults = self.defaults
        for t in typ:
            cast = defaults.get(t)
            if cast:
                self[t] = self._add_connection(cast)
            else:
                self.pop(t, None)
            # the array cast is rebuilt lazily from the reset base cast
            self.pop('_%s' % t, None)

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
499
500
# the global typecast dictionary, also used as defaults by the
# connection specific LocalTypecasts
_typecasts = Typecasts()


def get_typecast(typ):
    """Return the global typecast function for the database type typ."""
    casts = _typecasts
    return casts.get(typ)
507
508
def set_typecast(typ, cast):
    """Register cast as global typecast function for the given type(s).

    typ may be a single type name or a sequence of type names.  Passing
    None as cast removes the custom global typecast for these types.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    return _typecasts.set(typ, cast)
516
517
def reset_typecast(typ=None):
    """Restore the default global typecast(s) for a type name or a sequence.

    When typ is None, every global typecast is reset.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    _typecasts.reset(typ=typ)
527
528
class LocalTypecasts(Typecasts):
    """Map typecasts, including local composite types, to cast functions.

    Lookups fall back to the global typecasts.  Composite types without
    a cast of their own get a record cast built from the casts of their
    fields, and array types an array cast built from the element cast.
    """

    defaults = _typecasts

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Returns None if no cast function can be determined for typ.
        """
        if typ.startswith('_'):
            # array type: wrap the cast of the element type
            element_cast = self[typ[1:]]
            cast = self.create_array_cast(element_cast)
            if element_cast:
                self[typ] = cast  # cache only if the element type is known
            return cast
        cast = self.defaults.get(typ)
        if cast:
            cast = self._add_connection(cast)
            self[typ] = cast
            return cast
        fields = self.get_fields(typ)
        if fields:
            # composite type: build a record cast from the field casts
            casts = [self[field.type] for field in fields]
            names = [field.name for field in fields]
            cast = self.create_record_cast(typ, names, casts)
            self[typ] = cast
        return cast

    def get_fields(self, typ):
        """Return the fields of the given record type (none by default).

        TypeCache replaces this on its instance with a method that looks
        up the fields using the type cache of the connection.
        """
        return []
564
565
class TypeCode(str):
    """Type code used in the DB-API 2.0 description attribute.

    A TypeCode is a string equal to the PostgreSQL type name, and it
    carries further details of the type in the attributes oid, len,
    type, category, delim and relid.
    """

    @classmethod
    def create(cls, oid, name, len, type, category, delim, relid):
        """Return a type code for the type name with the given details.

        The arguments are taken from the pg_type columns oid, typname,
        typlen, typtype, typcategory, typdelim and typrelid.
        """
        code = cls(name)
        code.oid, code.len, code.type = oid, len, type
        code.category, code.delim, code.relid = category, delim, relid
        return code
584
# name and type (as found in the type cache) of a field of a composite type
FieldInfo = namedtuple('FieldInfo', ('name', 'type'))
586
587
class TypeCache(dict):
    """Cache for database types.

    Maps type OIDs as well as type names to TypeCode strings holding
    details of the associated database type from the pg_type catalog.
    Also manages the connection specific typecast functions.
    """

    def __init__(self, cnx):
        """Initialize the type cache for the given connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()
        typecasts = LocalTypecasts()
        # let the typecasts look up composite types through this cache
        typecasts.get_fields = self.get_fields
        typecasts.connection = cnx
        self._typecasts = typecasts

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        key may be a type OID (int) or a type name.  The result is cached
        under both the OID and the name.  Raises KeyError if the type
        could not be found.
        """
        if isinstance(key, int):
            oid = key
        else:
            if '.' not in key and '"' not in key:
                # wrap plain names in double quotes before the regtype cast
                key = '"%s"' % key
            oid = "'%s'::regtype" % self._escape_string(key)
        query = ("SELECT oid, typname,"
             " typlen, typtype, typcategory, typdelim, typrelid"
            " FROM pg_type WHERE oid=%s" % oid)
        try:
            self._src.execute(query)
        except ProgrammingError:
            rows = None
        else:
            rows = self._src.fetch(1)
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        row = rows[0]
        type_code = TypeCode.create(int(row[0]), row[1],
            int(row[2]), row[3], row[4], row[5], int(row[6]))
        self[type_code.oid] = type_code
        self[str(type_code)] = type_code
        return type_code

    def get(self, key, default=None):
        """Get the type even if it is not cached, or default if unknown."""
        try:
            found = self[key]
        except KeyError:
            found = default
        return found

    def get_fields(self, typ):
        """Get the names and types of the fields of composite types.

        Returns None if the type is unknown or not a composite type.
        """
        if not isinstance(typ, TypeCode):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None  # this type is not composite
        self._src.execute("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
        lookup = self.get
        return [FieldInfo(name, lookup(int(oid)))
            for name, oid in self._src.fetch(-1)]

    def get_typecast(self, typ):
        """Get the connection specific typecast function for the type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a connection specific typecast function for the type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the connection specific typecast function(s) for the type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        NULL values (None) and values without a special cast function
        (or with str as cast function) are returned unchanged.
        """
        if value is None:
            return None
        cast = self.get_typecast(typ)
        if cast and cast is not str:
            return cast(value)
        return value
671
672
673class _quotedict(dict):
674    """Dictionary with auto quoting of its items.
675
676    The quote attribute must be set to the desired quote function.
677    """
678
679    def __getitem__(self, key):
680        return self.quote(super(_quotedict, self).__getitem__(key))
681
682
683### Error messages
684
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls for msg, with sqlstate set to None."""
    err = cls(msg)
    # no SQLSTATE is known for errors raised by the module itself
    err.sqlstate = None
    return err
690
691
def _op_error(msg):
    """Create an OperationalError for msg, with sqlstate set to None."""
    return _db_error(msg, cls=OperationalError)
695
696
697### Cursor Object
698
699class Cursor(object):
700    """Cursor object."""
701
702    def __init__(self, dbcnx):
703        """Create a cursor object for the database connection."""
704        self.connection = self._dbcnx = dbcnx
705        self._cnx = dbcnx._cnx
706        self.type_cache = dbcnx.type_cache
707        self._src = self._cnx.source()
708        # the official attribute for describing the result columns
709        self._description = None
710        if self.row_factory is Cursor.row_factory:
711            # the row factory needs to be determined dynamically
712            self.row_factory = None
713        else:
714            self.build_row_factory = None
715        self.rowcount = -1
716        self.arraysize = 1
717        self.lastrowid = None
718
719    def __iter__(self):
720        """Make cursor compatible to the iteration protocol."""
721        return self
722
723    def __enter__(self):
724        """Enter the runtime context for the cursor object."""
725        return self
726
727    def __exit__(self, et, ev, tb):
728        """Exit the runtime context for the cursor object."""
729        self.close()
730
731    def _quote(self, value):
732        """Quote value depending on its type."""
733        if value is None:
734            return 'NULL'
735        if isinstance(value, (Hstore, Json)):
736            value = str(value)
737        if isinstance(value, basestring):
738            if isinstance(value, Binary):
739                value = self._cnx.escape_bytea(value)
740                if bytes is not str:  # Python >= 3.0
741                    value = value.decode('ascii')
742            else:
743                value = self._cnx.escape_string(value)
744            return "'%s'" % value
745        if isinstance(value, float):
746            if isinf(value):
747                return "'-Infinity'" if value < 0 else "'Infinity'"
748            if isnan(value):
749                return "'NaN'"
750            return value
751        if isinstance(value, (int, long, Decimal, Literal)):
752            return value
753        if isinstance(value, datetime):
754            if value.tzinfo:
755                return "'%s'::timestamptz" % value
756            return "'%s'::timestamp" % value
757        if isinstance(value, date):
758            return "'%s'::date" % value
759        if isinstance(value, time):
760            if value.tzinfo:
761                return "'%s'::timetz" % value
762            return "'%s'::time" % value
763        if isinstance(value, timedelta):
764            return "'%s'::interval" % value
765        if isinstance(value, Uuid):
766            return "'%s'::uuid" % value
767        if isinstance(value, list):
768            # Quote value as an ARRAY constructor. This is better than using
769            # an array literal because it carries the information that this is
770            # an array and not a string.  One issue with this syntax is that
771            # you need to add an explicit typecast when passing empty arrays.
772            # The ARRAY keyword is actually only necessary at the top level.
773            if not value:  # exception for empty array
774                return "'{}'"
775            q = self._quote
776            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
777        if isinstance(value, tuple):
778            # Quote as a ROW constructor.  This is better than using a record
779            # literal because it carries the information that this is a record
780            # and not a string.  We don't use the keyword ROW in order to make
781            # this usable with the IN syntax as well.  It is only necessary
782            # when the records has a single column which is not really useful.
783            q = self._quote
784            return '(%s)' % ','.join(str(q(v)) for v in value)
785        try:
786            value = value.__pg_repr__()
787        except AttributeError:
788            raise InterfaceError(
789                'Do not know how to adapt type %s' % type(value))
790        if isinstance(value, (tuple, list)):
791            value = self._quote(value)
792        return value
793
794    def _quoteparams(self, string, parameters):
795        """Quote parameters.
796
797        This function works for both mappings and sequences.
798
799        The function should be used even when there are no parameters,
800        so that we have a consistent behavior regarding percent signs.
801        """
802        if not parameters:
803            try:
804                return string % ()  # unescape literal quotes if possible
805            except (TypeError, ValueError):
806                return string  # silently accept unescaped quotes
807        if isinstance(parameters, dict):
808            parameters = _quotedict(parameters)
809            parameters.quote = self._quote
810        else:
811            parameters = tuple(map(self._quote, parameters))
812        return string % parameters
813
814    def _make_description(self, info):
815        """Make the description tuple for the given field info."""
816        name, typ, size, mod = info[1:]
817        type_code = self.type_cache[typ]
818        if mod > 0:
819            mod -= 4
820        if type_code == 'numeric':
821            precision, scale = mod >> 16, mod & 0xffff
822            size = precision
823        else:
824            if not size:
825                size = type_code.size
826            if size == -1:
827                size = mod
828            precision = scale = None
829        return CursorDescription(name, type_code,
830            None, size, precision, scale, None)
831
832    @property
833    def description(self):
834        """Read-only attribute describing the result columns."""
835        descr = self._description
836        if self._description is True:
837            make = self._make_description
838            descr = [make(info) for info in self._src.listinfo()]
839            self._description = descr
840        return descr
841
842    @property
843    def colnames(self):
844        """Unofficial convenience method for getting the column names."""
845        return [d[0] for d in self.description]
846
847    @property
848    def coltypes(self):
849        """Unofficial convenience method for getting the column types."""
850        return [d[1] for d in self.description]
851
852    def close(self):
853        """Close the cursor object."""
854        self._src.close()
855
856    def execute(self, operation, parameters=None):
857        """Prepare and execute a database operation (query or command)."""
858        # The parameters may also be specified as list of tuples to e.g.
859        # insert multiple rows in a single operation, but this kind of
860        # usage is deprecated.  We make several plausibility checks because
861        # tuples can also be passed with the meaning of ROW constructors.
862        if (parameters and isinstance(parameters, list)
863                and len(parameters) > 1
864                and all(isinstance(p, tuple) for p in parameters)
865                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
866            return self.executemany(operation, parameters)
867        else:
868            # not a list of tuples
869            return self.executemany(operation, [parameters])
870
    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        A transaction is opened on the connection first if none is active.
        Each parameter set is merged into the operation with _quoteparams()
        and executed on the result source.  Afterwards rowcount, lastrowid
        and the (lazily built) description are set from the last result.

        Returns the cursor itself, or None when seq_of_parameters is empty
        (in that case nothing is executed and no attribute is changed).
        Database errors are re-raised as they are; other failures are
        wrapped with _db_error()/_op_error().
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self._description = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        # sql always holds the statement being run, for the error messages
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                # no open transaction on the connection yet: start one
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("Can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                # a truthy return value is accumulated as an affected-row
                # count (presumably returned for DML only -- confirm in _pg)
                if rows:
                    rowcount += rows
                else:
                    # redundant: rowcount is still -1 here, and the final
                    # value is assigned after the loop
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "Error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("Internal error in '%s': %s" % (sql, err))
        # then initialize result row count and description
        if self._src.resulttype == RESULT_DQL:
            self._description = True  # fetch on demand
            self.rowcount = self._src.ntuples
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self
918
919    def fetchone(self):
920        """Fetch the next row of a query result set."""
921        res = self.fetchmany(1, False)
922        try:
923            return res[0]
924        except IndexError:
925            return None
926
927    def fetchall(self):
928        """Fetch all (remaining) rows of a query result."""
929        return self.fetchmany(-1, False)
930
931    def fetchmany(self, size=None, keep=False):
932        """Fetch the next set of rows of a query result.
933
934        The number of rows to fetch per call is specified by the
935        size parameter. If it is not given, the cursor's arraysize
936        determines the number of rows to be fetched. If you set
937        the keep parameter to true, this is kept as new arraysize.
938        """
939        if size is None:
940            size = self.arraysize
941        if keep:
942            self.arraysize = size
943        try:
944            result = self._src.fetch(size)
945        except DatabaseError:
946            raise
947        except Error as err:
948            raise _db_error(str(err))
949        typecast = self.type_cache.typecast
950        return [self.row_factory([typecast(value, typ)
951            for typ, value in zip(self.coltypes, row)]) for row in result]
952
953    def callproc(self, procname, parameters=None):
954        """Call a stored database procedure with the given name.
955
956        The sequence of parameters must contain one entry for each input
957        argument that the procedure expects. The result of the call is the
958        same as this input sequence; replacement of output and input/output
959        parameters in the return value is currently not supported.
960
961        The procedure may also provide a result set as output. These can be
962        requested through the standard fetch methods of the cursor.
963        """
964        n = parameters and len(parameters) or 0
965        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
966        self.execute(query, parameters)
967        return parameters
968
969    def copy_from(self, stream, table,
970            format=None, sep=None, null=None, size=None, columns=None):
971        """Copy data from an input stream to the specified table.
972
973        The input stream can be a file-like object with a read() method or
974        it can also be an iterable returning a row or multiple rows of input
975        on each iteration.
976
977        The format must be text, csv or binary. The sep option sets the
978        column separator (delimiter) used in the non binary formats.
979        The null option sets the textual representation of NULL in the input.
980
981        The size option sets the size of the buffer used when reading data
982        from file-like objects.
983
984        The copy operation can be restricted to a subset of columns. If no
985        columns are specified, all of them will be copied.
986        """
987        binary_format = format == 'binary'
988        try:
989            read = stream.read
990        except AttributeError:
991            if size:
992                raise ValueError("Size must only be set for file-like objects")
993            if binary_format:
994                input_type = bytes
995                type_name = 'byte strings'
996            else:
997                input_type = basestring
998                type_name = 'strings'
999
1000            if isinstance(stream, basestring):
1001                if not isinstance(stream, input_type):
1002                    raise ValueError("The input must be %s" % type_name)
1003                if not binary_format:
1004                    if isinstance(stream, str):
1005                        if not stream.endswith('\n'):
1006                            stream += '\n'
1007                    else:
1008                        if not stream.endswith(b'\n'):
1009                            stream += b'\n'
1010
1011                def chunks():
1012                    yield stream
1013
1014            elif isinstance(stream, Iterable):
1015
1016                def chunks():
1017                    for chunk in stream:
1018                        if not isinstance(chunk, input_type):
1019                            raise ValueError(
1020                                "Input stream must consist of %s" % type_name)
1021                        if isinstance(chunk, str):
1022                            if not chunk.endswith('\n'):
1023                                chunk += '\n'
1024                        else:
1025                            if not chunk.endswith(b'\n'):
1026                                chunk += b'\n'
1027                        yield chunk
1028
1029            else:
1030                raise TypeError("Need an input stream to copy from")
1031        else:
1032            if size is None:
1033                size = 8192
1034            elif not isinstance(size, int):
1035                raise TypeError("The size option must be an integer")
1036            if size > 0:
1037
1038                def chunks():
1039                    while True:
1040                        buffer = read(size)
1041                        yield buffer
1042                        if not buffer or len(buffer) < size:
1043                            break
1044
1045            else:
1046
1047                def chunks():
1048                    yield read()
1049
1050        if not table or not isinstance(table, basestring):
1051            raise TypeError("Need a table to copy to")
1052        if table.lower().startswith('select'):
1053                raise ValueError("Must specify a table, not a query")
1054        else:
1055            table = '"%s"' % (table,)
1056        operation = ['copy %s' % (table,)]
1057        options = []
1058        params = []
1059        if format is not None:
1060            if not isinstance(format, basestring):
1061                raise TypeError("The frmat option must be be a string")
1062            if format not in ('text', 'csv', 'binary'):
1063                raise ValueError("Invalid format")
1064            options.append('format %s' % (format,))
1065        if sep is not None:
1066            if not isinstance(sep, basestring):
1067                raise TypeError("The sep option must be a string")
1068            if format == 'binary':
1069                raise ValueError(
1070                    "The sep option is not allowed with binary format")
1071            if len(sep) != 1:
1072                raise ValueError(
1073                    "The sep option must be a single one-byte character")
1074            options.append('delimiter %s')
1075            params.append(sep)
1076        if null is not None:
1077            if not isinstance(null, basestring):
1078                raise TypeError("The null option must be a string")
1079            options.append('null %s')
1080            params.append(null)
1081        if columns:
1082            if not isinstance(columns, basestring):
1083                columns = ','.join('"%s"' % (col,) for col in columns)
1084            operation.append('(%s)' % (columns,))
1085        operation.append("from stdin")
1086        if options:
1087            operation.append('(%s)' % ','.join(options))
1088        operation = ' '.join(operation)
1089
1090        putdata = self._src.putdata
1091        self.execute(operation, params)
1092
1093        try:
1094            for chunk in chunks():
1095                putdata(chunk)
1096        except BaseException as error:
1097            self.rowcount = -1
1098            # the following call will re-raise the error
1099            putdata(error)
1100        else:
1101            self.rowcount = putdata(None)
1102
1103        # return the cursor object, so you can chain operations
1104        return self
1105
1106    def copy_to(self, stream, table,
1107            format=None, sep=None, null=None, decode=None, columns=None):
1108        """Copy data from the specified table to an output stream.
1109
1110        The output stream can be a file-like object with a write() method or
1111        it can also be None, in which case the method will return a generator
1112        yielding a row on each iteration.
1113
1114        Output will be returned as byte strings unless you set decode to true.
1115
1116        Note that you can also use a select query instead of the table name.
1117
1118        The format must be text, csv or binary. The sep option sets the
1119        column separator (delimiter) used in the non binary formats.
1120        The null option sets the textual representation of NULL in the output.
1121
1122        The copy operation can be restricted to a subset of columns. If no
1123        columns are specified, all of them will be copied.
1124        """
1125        binary_format = format == 'binary'
1126        if stream is not None:
1127            try:
1128                write = stream.write
1129            except AttributeError:
1130                raise TypeError("Need an output stream to copy to")
1131        if not table or not isinstance(table, basestring):
1132            raise TypeError("Need a table to copy to")
1133        if table.lower().startswith('select'):
1134            if columns:
1135                raise ValueError("Columns must be specified in the query")
1136            table = '(%s)' % (table,)
1137        else:
1138            table = '"%s"' % (table,)
1139        operation = ['copy %s' % (table,)]
1140        options = []
1141        params = []
1142        if format is not None:
1143            if not isinstance(format, basestring):
1144                raise TypeError("The format option must be a string")
1145            if format not in ('text', 'csv', 'binary'):
1146                raise ValueError("Invalid format")
1147            options.append('format %s' % (format,))
1148        if sep is not None:
1149            if not isinstance(sep, basestring):
1150                raise TypeError("The sep option must be a string")
1151            if binary_format:
1152                raise ValueError(
1153                    "The sep option is not allowed with binary format")
1154            if len(sep) != 1:
1155                raise ValueError(
1156                    "The sep option must be a single one-byte character")
1157            options.append('delimiter %s')
1158            params.append(sep)
1159        if null is not None:
1160            if not isinstance(null, basestring):
1161                raise TypeError("The null option must be a string")
1162            options.append('null %s')
1163            params.append(null)
1164        if decode is None:
1165            if format == 'binary':
1166                decode = False
1167            else:
1168                decode = str is unicode
1169        else:
1170            if not isinstance(decode, (int, bool)):
1171                raise TypeError("The decode option must be a boolean")
1172            if decode and binary_format:
1173                raise ValueError(
1174                    "The decode option is not allowed with binary format")
1175        if columns:
1176            if not isinstance(columns, basestring):
1177                columns = ','.join('"%s"' % (col,) for col in columns)
1178            operation.append('(%s)' % (columns,))
1179
1180        operation.append("to stdout")
1181        if options:
1182            operation.append('(%s)' % ','.join(options))
1183        operation = ' '.join(operation)
1184
1185        getdata = self._src.getdata
1186        self.execute(operation, params)
1187
1188        def copy():
1189            self.rowcount = 0
1190            while True:
1191                row = getdata(decode)
1192                if isinstance(row, int):
1193                    if self.rowcount != row:
1194                        self.rowcount = row
1195                    break
1196                self.rowcount += 1
1197                yield row
1198
1199        if stream is None:
1200            # no input stream, return the generator
1201            return copy()
1202
1203        # write the rows to the file-like input stream
1204        for row in copy():
1205            write(row)
1206
1207        # return the cursor object, so you can chain operations
1208        return self
1209
1210    def __next__(self):
1211        """Return the next row (support for the iteration protocol)."""
1212        res = self.fetchone()
1213        if res is None:
1214            raise StopIteration
1215        return res
1216
1217    # Note that since Python 2.6 the iterator protocol uses __next()__
1218    # instead of next(), we keep it only for backward compatibility of pgdb.
1219    next = __next__
1220
1221    @staticmethod
1222    def nextset():
1223        """Not supported."""
1224        raise NotSupportedError("The nextset() method is not supported")
1225
1226    @staticmethod
1227    def setinputsizes(sizes):
1228        """Not supported."""
1229        pass  # unsupported, but silently passed
1230
1231    @staticmethod
1232    def setoutputsize(size, column=0):
1233        """Not supported."""
1234        pass  # unsupported, but silently passed
1235
1236    @staticmethod
1237    def row_factory(row):
1238        """Process rows before they are returned.
1239
1240        You can overwrite this statically with a custom row factory, or
1241        you can build a row factory dynamically with build_row_factory().
1242
1243        For example, you can create a Cursor class that returns rows as
1244        Python dictionaries like this:
1245
1246            class DictCursor(pgdb.Cursor):
1247
1248                def row_factory(self, row):
1249                    return {desc[0]: value
1250                        for desc, value in zip(self.description, row)}
1251
1252            cur = DictCursor(con)  # get one DictCursor instance or
1253            con.cursor_type = DictCursor  # always use DictCursor instances
1254        """
1255        raise NotImplementedError
1256
1257    def build_row_factory(self):
1258        """Build a row factory based on the current description.
1259
1260        This implementation builds a row factory for creating named tuples.
1261        You can overwrite this method if you want to dynamically create
1262        different row factories whenever the column description changes.
1263        """
1264        colnames = self.colnames
1265        if colnames:
1266            try:
1267                try:
1268                    return namedtuple('Row', colnames, rename=True)._make
1269                except TypeError:  # Python 2.6 and 3.0 do not support rename
1270                    colnames = [v if v.isalnum() else 'column_%d' % n
1271                             for n, v in enumerate(colnames)]
1272                    return namedtuple('Row', colnames)._make
1273            except ValueError:  # there is still a problem with the field names
1274                colnames = ['column_%d' % n for n in range(len(colnames))]
1275                return namedtuple('Row', colnames)._make
1276
1277
# Named tuple for one entry of cursor.description: the seven items
# required by DB-API 2, accessible by position or by name.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
1281
1282
1283### Connection Objects
1284
class Connection(object):
    """DB-API 2 connection object wrapping a low-level connection.

    Cursors open transactions implicitly; commit() and rollback() end
    them.  Used as a context manager, the connection commits when the
    block succeeds and rolls back when it raises.
    """

    # The DB-API 2 exception classes are also reachable as attributes of
    # every connection (an optional DB-API 2 extension).
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Wrap cnx and check that it can provide a query source."""
        self._cnx = cnx  # low-level connection, None once closed
        self._tnx = False  # whether a transaction is currently open
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # class instantiated by cursor()
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("Invalid connection")

    def __enter__(self):
        """Enter a transaction context; returns the connection itself."""
        return self

    def __exit__(self, et, ev, tb):
        """Leave the transaction context without closing the connection.

        Commits if the block finished normally, otherwise rolls back.
        """
        failed = et is not None or ev is not None or tb is not None
        if failed:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection, rolling back an open transaction first.

        Raises the error built by _op_error() if already closed.
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # best effort, the connection is closed anyway
        self._cnx.close()
        self._cnx = None

    @property
    def closed(self):
        """True if the connection has been closed or is broken."""
        try:
            if not self._cnx:
                return True
            return self._cnx.status != 1
        except TypeError:
            return True

    def _end_transaction(self, command, failure):
        """Send command (COMMIT or ROLLBACK) if a transaction is open.

        The transaction flag is cleared before the command is sent.
        Database errors propagate; other failures are reported with
        the given failure message.
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return
        self._tnx = False
        try:
            self._cnx.source().execute(command)
        except DatabaseError:
            raise
        except Exception:
            raise _op_error(failure)

    def commit(self):
        """Commit any pending transaction to the database."""
        self._end_transaction("COMMIT", "Can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        self._end_transaction("ROLLBACK", "Can't rollback")

    def cursor(self):
        """Return a new cursor (of class cursor_type) for this connection."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("Invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Run an operation on a new cursor and return that cursor."""
            implicit = self.cursor()
            implicit.execute(operation, params)
            return implicit

        def executemany(self, operation, param_seq):
            """Run an operation against a sequence on a new cursor."""
            implicit = self.cursor()
            implicit.executemany(operation, param_seq)
            return implicit
1400
1401
1402### Module Interface
1403
# Keep a reference to the low-level connect function (presumably imported
# from the C extension module at the top of the file) before the DB-API
# connect() function below rebinds the name; connect() calls _connect().
_connect = connect
1405
def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database.

    The dsn string has the form 'host:database:user:password:opt', and
    trailing parts may be omitted.  The keyword arguments take precedence
    over the corresponding dsn parts.  A port can be passed as part of
    host ('host:port'); a host or port that cannot be parsed is ignored.
    Empty host and user values are passed on as None.

    Returns a Connection object wrapping the new low-level connection.
    """
    # the five colon-separated DSN fields, empty unless given
    fields = ["", "", "", "", ""]
    try:
        for index, value in enumerate(dsn.split(":")[:5]):
            fields[index] = value
    except (AttributeError, IndexError, TypeError):
        pass  # no usable DSN string: keep the defaults
    dbhost, dbbase, dbuser, dbpasswd, dbopt = fields
    dbport = -1  # presumably selects the default port in _connect()

    # explicit keyword arguments override the DSN
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            host_fields = host.split(":")
            dbhost = host_fields[0]
            dbport = int(host_fields[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # no (numeric) port given: keep the default

    # empty values are passed as None (an empty host means localhost)
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    cnx = _connect(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
1451
1452
1453### Types Handling
1454
class Type(frozenset):
    """Set of PostgreSQL type names acting as a DB-API 2 type object.

    Since PostgreSQL types are dynamic, type names serve as the internal
    type codes.  A Type compares equal to any string naming one of its
    types, including the array variant of that name (leading underscore).
    Comparisons with non-strings behave like those of a frozenset.
    """

    def __new__(cls, values):
        # a single string is treated as whitespace-separated type names
        names = values.split() if isinstance(values, basestring) else values
        return super(Type, cls).__new__(cls, names)

    @staticmethod
    def _base_name(name):
        """Strip the leading underscore that marks array type names."""
        return name[1:] if name.startswith('_') else name

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return super(Type, self).__eq__(other)
        return self._base_name(other) in self

    def __ne__(self, other):
        if not isinstance(other, basestring):
            return super(Type, self).__ne__(other)
        return self._base_name(other) not in self
1482
1483
class ArrayType:
    """Type class matching PostgreSQL array types.

    Array type names start with an underscore (such as '_int4'), so any
    such string compares equal to this object; other objects compare
    equal only if they are ArrayType instances as well.
    """

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return isinstance(other, ArrayType)
        return other.startswith('_')

    def __ne__(self, other):
        # __eq__ always returns a bool, so this is its exact negation
        return not self.__eq__(other)
1498
1499
class RecordType:
    """Type class matching PostgreSQL record (composite) types.

    Compares equal to type codes whose type attribute is 'c', to the type
    name 'record', and to other RecordType instances.
    """

    def __eq__(self, other):
        if isinstance(other, TypeCode):
            matches = other.type == 'c'
        elif isinstance(other, basestring):
            matches = other == 'record'
        else:
            matches = isinstance(other, RecordType)
        return matches

    def __ne__(self, other):
        if isinstance(other, TypeCode):
            differs = other.type != 'c'
        elif isinstance(other, basestring):
            differs = other != 'record'
        else:
            differs = not isinstance(other, RecordType)
        return differs
1518
1519
1520# Mandatory type objects defined by DB-API 2 specs:
1521
# Each Type object compares equal to any of the listed type names and,
# because Type strips a leading underscore, to their array type names too.
STRING = Type('char bpchar name text varchar')
BINARY = Type('bytea')
# NOTE(review): 'serial' is a column pseudo-type; presumably listed for
# convenience, since results report the underlying int4 -- confirm.
NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
DATETIME = Type('date time timetz timestamp timestamptz interval'
    ' abstime reltime')  # these are very old
ROWID = Type('oid')


# Additional type objects (more specific):

BOOL = Type('bool')
SMALLINT = Type('int2')
INTEGER = Type('int2 int4 int8 serial')
LONG = Type('int8')
FLOAT = Type('float4 float8')
NUMERIC = Type('numeric')
MONEY = Type('money')
DATE = Type('date')
TIME = Type('time timetz')
TIMESTAMP = Type('timestamp timestamptz')
INTERVAL = Type('interval')
UUID = Type('uuid')
HSTORE = Type('hstore')
JSON = Type('json jsonb')

# Type object for arrays: equals every array type name (leading '_').
# The base type objects above also equal their array type names.

ARRAY = ArrayType()

# Type object for records (encompassing all composite types):

RECORD = RecordType()
1554
1555
1556# Mandatory type helpers defined by DB-API 2 specs:
1557
def Date(year, month, day):
    """Return a datetime.date for the given day (DB-API 2 constructor)."""
    return date(year=year, month=month, day=day)
1561
1562
def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Return a datetime.time for the given clock time (DB-API 2)."""
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond, tzinfo=tzinfo)
1566
1567
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Return a datetime.datetime for the given moment (DB-API 2)."""
    return datetime(year=year, month=month, day=day,
                    hour=hour, minute=minute, second=second,
                    microsecond=microsecond, tzinfo=tzinfo)
1572
1573
def DateFromTicks(ticks):
    """Return the local date for a number of seconds since the epoch."""
    year, month, day = localtime(ticks)[:3]
    return date(year, month, day)
1577
1578
def TimeFromTicks(ticks):
    """Return the local clock time for seconds since the epoch.

    Fractions of a second are dropped.
    """
    hour, minute, second = localtime(ticks)[3:6]
    return time(hour, minute, second)
1582
1583
def TimestampFromTicks(ticks):
    """Return the local date and time for seconds since the epoch.

    Fractions of a second are dropped.
    """
    fields = localtime(ticks)[:6]
    return datetime(*fields)
1587
1588
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is the DB-API 2 Binary constructor: a plain subclass of bytes
    without additional behavior.  Presumably the distinct type lets the
    parameter quoting code recognize bytea values (TODO confirm).
    """
1591
1592
1593# Additional type helpers for PyGreSQL:
1594
def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
    """Construct an object holding a time interval value (a timedelta)."""
    return timedelta(days=days, hours=hours, minutes=minutes,
                     seconds=seconds, microseconds=microseconds)
1599
1600
# Re-export the UUID constructor under the DB-API style name; the right-hand
# Uuid is presumably imported at the top of the file (TODO confirm source).
Uuid = Uuid  # Construct an object holding a UUID value
1602
1603
class Hstore(dict):
    """Wrapper class for marking hstore values.

    str() renders the dictionary in the hstore text input format, for
    instance Hstore({'a': '1', 'b': None}) gives 'a=>1,b=>NULL' (pairs
    appear in the dictionary's iteration order).
    """

    # Keys and values must be double-quoted when they spell NULL (in any
    # case) or contain whitespace, commas, '=' or '>'.  The hstore parser
    # treats any unquoted whitespace (tabs and newlines, too) as a
    # separator, so all whitespace has to trigger quoting, not just blanks.
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>]')
    # double quotes and backslashes are always escaped with a backslash
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s):
        """Return s as an hstore key/value literal (None becomes NULL)."""
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        quote = cls._re_quote.search(s)
        s = cls._re_escape.sub(r'\\\1', s)
        if quote:
            s = '"%s"' % s
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
1625
1626
class Json:
    """Wrap an object so that it is sent to the database as JSON text.

    str() returns string values unchanged (they should already be JSON
    text) and serializes anything else with the encode function, which
    defaults to jsonencode when none (or a false value) is given.
    """

    def __init__(self, obj, encode=None):
        self.obj = obj  # the wrapped value
        if not encode:
            encode = jsonencode
        self.encode = encode

    def __str__(self):
        value = self.obj
        if not isinstance(value, basestring):
            value = self.encode(value)
        return value
1639
1640
class Literal:
    """Wrap a string that is to be used as literal SQL, without quoting."""

    def __init__(self, sql):
        self.sql = sql  # stored and returned exactly as given

    def __str__(self):
        text = self.sql
        return text

    # The same method serves as the __pg_repr__ hook, presumably consulted
    # by the parameter quoting code to obtain the SQL representation.
    __pg_repr__ = __str__
1651
1652# If run as script, print some information:
1653
if __name__ == '__main__':
    # When run as a script, show the module version and its documentation.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.