source: trunk/pgdb.py @ 847

Last change on this file since 847 was 847, checked in by cito, 4 years ago

Fix issue with arrays and records containing unicode

Arrays and records with non-ascii unicode elements
did not work properly in Python 2.

  • Property svn:keywords set to Id
File size: 54.4 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 847 2016-02-09 10:51:46Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary or a sequence)
32    # if they are passed. The binding syntax is the same as for the %
33    # operator, and the parameter values are quoted according to their types.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size and null_ok are not implemented, and
52    # precision and scale are only provided for numeric columns.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70__version__ = version
71
72from datetime import date, time, datetime, timedelta
73from time import localtime
74from decimal import Decimal
75from uuid import UUID as Uuid
76from math import isnan, isinf
77from collections import namedtuple
78from functools import partial
79from re import compile as regex
80from json import loads as jsondecode, dumps as jsonencode
81
# Compatibility names so that the same code runs under Python 2 and 3.

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

try:
    # The abstract base classes live in collections.abc since Python 3.3;
    # the old aliases in the collections module were removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
98
99
### Module Constants

# the module complies with the DB API 2.0 specification (PEP 249)
apilevel = '2.0'

# threads may share the module, but not connections
threadsafety = 1

# parameters are bound using Python extended format codes like %(name)s
paramstyle = 'pyformat'

# Shortcut methods were excluded from DB API 2 and are not recommended
# by the DB SIG, but they can be handy, so this flag is set.
shortcutmethods = 1
114
115
116### Internal Type Handling
117
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the list of parameter names of the given function."""
        params = signature(func).parameters
        return [name for name in params]
129
try:
    # strptime() rejects the %z directive in Python versions before 3.2
    if datetime.strptime('+0100', '%z') is None:
        raise ValueError
except ValueError:  # Python < 3.2
    timezones = None
else:
    # UTC offsets of the time zone abbreviations that may appear
    # in the timestamptz output of Postgres
    timezones = {
        'CET': '+0100', 'EET': '+0200', 'EST': '-0500',
        'GMT': '+0000', 'HST': '-1000', 'MET': '+0100',
        'MST': '-0700', 'UCT': '+0000', 'UTC': '+0000',
        'WET': '+0000'}
140
141
def decimal_type(decimal_type=None):
    """Get or set the global type used for numeric values.

    When decimal_type is given, it becomes the new global type and is also
    registered as the global typecast for 'numeric'.  The current global
    type is returned in any case.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    Decimal = decimal_type
    set_typecast('numeric', decimal_type)
    return Decimal
153
154
def cast_bool(value):
    """Convert a boolean in database format to a Python bool.

    Only the first character is inspected: 't' or 'T' means True,
    anything else False.  An empty value gives None.
    """
    if not value:
        return None
    return value[0] in ('t', 'T')
159
160
def cast_money(value):
    """Convert a money value in database format to a Decimal.

    An opening parenthesis (accounting notation for a negative amount)
    is turned into a minus sign; then every character other than digits,
    '.' and '-' is dropped.  An empty value gives None.
    """
    if not value:
        return None
    signed = value.replace('(', '-')
    kept = [char for char in signed if char.isdigit() or char in '.-']
    return Decimal(''.join(kept))
166
167
def cast_int2vector(value):
    """Convert an int2vector value (whitespace separated integers) to a list."""
    return list(map(int, value.split()))
171
172
def cast_date(value, connection):
    """Convert a date in database format to a Python date.

    The output format depends on the server setting DateStyle.  The ISO
    and German settings are unambiguous, but the order of days and months
    in the other settings is not, so the strptime() format is obtained
    from connection.date_format().

    '-infinity' and BC dates give date.min; 'infinity' and dates longer
    than 10 characters (years beyond 9999) give date.max.
    """
    if value == '-infinity':
        return date.min
    if value == 'infinity':
        return date.max
    words = value.split()
    if words[-1] == 'BC':
        return date.min
    day = words[0]
    if len(day) > 10:
        return date.max
    return datetime.strptime(day, connection.date_format()).date()
191
192
def cast_time(value):
    """Convert a time in database format to a Python time."""
    if len(value) > 8:
        # fractional seconds follow the HH:MM:SS part
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    parsed = datetime.strptime(value, fmt)
    return parsed.time()
197
198
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Convert a time with time zone in database format to a Python time.

    The offset is split off at the last '+' or '-' sign; a value without
    offset is taken as UTC.  If strptime() does not support %z
    (timezones is None), the offset is dropped and a naive time results.
    """
    match = _re_timezone.match(value)
    value, tz = match.groups() if match else (value, '+0000')
    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
    if timezones:
        if tz.startswith(('+', '-')):
            # normalize '+01' to '+0100' and '+05:30' to '+0530'
            tz = tz + '00' if len(tz) < 5 else tz.replace(':', '')
        else:
            tz = timezones.get(tz, '+0000')
        value += tz
        fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
223
224
def cast_timestamp(value, connection):
    """Convert a timestamp in database format to a Python datetime.

    The date part format is obtained from connection.date_format().  If
    that format ends with '-%Y' and the value has more than two words, the
    value is expected to consist of a leading weekday, month and day (in
    either order), the time and the year.

    '-infinity' and BC timestamps give datetime.min; 'infinity' and years
    beyond 9999 give datetime.max.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # skip the weekday, keep day/month, time and year
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        day_month = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        clock = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        full_fmt = ' '.join((day_month, clock, '%Y'))
    else:
        if len(parts[0]) > 10:
            return datetime.max
        clock = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        full_fmt = ' '.join((date_fmt, clock))
    return datetime.strptime(' '.join(parts), full_fmt)
246
247
def cast_timestamptz(value, connection):
    """Convert a timestamp with time zone in database format to a datetime.

    The date part format is obtained from connection.date_format().  With
    an ISO format ('%Y-...') the offset is attached to the time part;
    otherwise the zone is the last word of the value.  Zone abbreviations
    are looked up in timezones (unknown ones are taken as UTC).  If
    strptime() does not support %z (timezones is None), the zone is
    dropped and a naive datetime results.

    '-infinity' and BC timestamps give datetime.min; 'infinity' and years
    beyond 9999 give datetime.max.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # skip the weekday, keep day/month, time, year and zone
        parts = parts[1:]
        if len(parts[3]) > 4:
            return datetime.max
        fmt = ['%d %b' if date_fmt.startswith('%d') else '%b %d',
               '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S', '%Y']
        tz = parts.pop()
    else:
        if date_fmt.startswith('%Y-'):
            # the offset is attached to the time part
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        else:
            # the zone is a separate last word
            tz = parts.pop()
        if len(parts[0]) > 10:
            return datetime.max
        fmt = [date_fmt, '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S']
    if timezones:
        if tz.startswith(('+', '-')):
            # normalize '+01' to '+0100' and '+05:30' to '+0530'
            tz = tz + '00' if len(tz) < 5 else tz.replace(':', '')
        else:
            tz = timezones.get(tz, '+0000')
        parts.append(tz)
        fmt.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(fmt))
290
# Regular expressions for the four interval output formats of Postgres,
# one for each setting of IntervalStyle.  Apart from the sign groups,
# every group captures one numeric component of the interval; the group
# after the decimal point of the seconds holds the fractional digits.

_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_usecs(fraction):
    """Convert the digits after the decimal point of seconds to microseconds.

    The fraction may have fewer than six digits (e.g. '5' in '01.5'), so it
    must be padded on the right instead of being read as a plain integer.
    Digits beyond microsecond precision are truncated.
    """
    return int(fraction.ljust(6, '0')[:6])


def cast_interval(value):
    """Cast an interval value to a timedelta.

    Years and months are approximated as 365 and 30 days, respectively.

    Raises ValueError if the value matches none of the known formats.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    # In every branch, the last collected group holds the fraction of the
    # seconds, which is converted with _interval_usecs().
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        secs_ago = m.pop(5) == '-'
        m = [int(d) for d in m[:-1]] + [_interval_usecs(m[-1])]
        years, mons, days, hours, mins, secs, usecs = m
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            secs_ago = m.pop(5) == '-'
            m = [int(d) for d in m[:-1]] + [_interval_usecs(m[-1])]
            if ago:
                m = [-d for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                hours_ago = m.pop(3) == '-'
                m = [int(d) for d in m[:-1]] + [_interval_usecs(m[-1])]
                years, mons, days, hours, mins, secs, usecs = m
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = [d or '0' for d in m.groups()]
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    m = [int(d) for d in m[:-1]] + [_interval_usecs(m[-1])]
                    years, mons, days, hours, mins, secs, usecs = m
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
376
377
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.

    If the instance has a connection, cast functions with a parameter
    named 'connection' (after the first one) get it bound when they are
    stored (see _add_connection).
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
        'float4': float, 'float8': float,
        'numeric': Decimal, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': Uuid,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in local connection specific instances

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.
        Array types ('_' followed by the element type name) get an array
        cast built on the cast of the element type.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # create array cast
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                # store only if base type exists
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument."""
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type."""
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        The cached cast for the corresponding array type is dropped, so that
        it will be rebuilt on top of the new cast.  Passing None as cast
        only drops the cached casts for the type(s) and their array types.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.

        The cached entries are only dropped; __missing__ recreates them
        from the defaults on the next access.  This also binds the
        connection where a cast function needs it, which matters for
        connection specific subclasses whose defaults are connection-less
        cast functions (copying those directly would lose the binding).
        """
        if typ is None:
            self.clear()
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts.

        The returned function casts a record value with cast_record() and
        returns the result as a named tuple.  Field names that are not
        valid Python identifiers (e.g. keywords, duplicates or names with
        spaces) are replaced with positional names, and if the type name
        itself is not a valid identifier, 'Record' is used as class name.
        """
        try:
            record = namedtuple(name, fields, rename=True)
        except ValueError:  # the type name is not a valid identifier
            record = namedtuple('Record', fields, rename=True)

        def cast(v):
            return record(*cast_record(v, casts))
        return cast
499
500
_typecasts = Typecasts()  # this is the global typecast dictionary


def get_typecast(typ):
    """Return the global typecast function for the given database type.

    None is returned when there is no special cast function for the type.
    """
    casts = _typecasts
    return casts.get(typ)


def set_typecast(typ, cast):
    """Register a global typecast function for the given database type(s).

    The type may be a single name or a list of names.  Passing None as
    cast drops the cached cast functions for the type(s) and their arrays.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    casts = _typecasts
    casts.set(typ, cast)


def reset_typecast(typ=None):
    """Restore the default global typecasts for the given type(s).

    Without a type, all global typecasts are restored.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    casts = _typecasts
    casts.reset(typ)
527
528
class LocalTypecasts(Typecasts):
    """Map typecasts, including local composite types, to cast functions.

    Missing cast functions are taken from the global typecasts, bound to
    the connection where needed, and cached.  Array casts are built on the
    cast of the element type, and composite types get a named record cast
    built from the casts of their fields.
    """

    defaults = _typecasts

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Array casts are only cached if the element type has a cast;
        None is returned if no cast function exists for a plain type.
        """
        if typ.startswith('_'):
            # array type: build on the cast of the element type
            element_cast = self[typ[1:]]
            cast = self.create_array_cast(element_cast)
            if element_cast:
                self[typ] = cast
            return cast
        cast = self.defaults.get(typ)
        if cast:
            cast = self._add_connection(cast)
            self[typ] = cast
            return cast
        fields = self.get_fields(typ)
        if fields:
            # composite type: cast every field with the cast of its type
            field_casts = [self[field.type] for field in fields]
            field_names = [field.name for field in fields]
            cast = self.create_record_cast(typ, field_names, field_casts)
            self[typ] = cast
        return cast

    def get_fields(self, typ):
        """Return the fields for the given record type.

        This placeholder knows no composite types and returns an empty list;
        it is meant to be replaced with a method that looks up the fields
        using the type cache of the connection.
        """
        return []
564
565
class TypeCode(str):
    """Class representing the type_code used by the DB-API 2.0.

    TypeCode objects are strings equal to the PostgreSQL type name,
    but carry the additional attributes oid, len, type, category,
    delim and relid describing the type.
    """

    @classmethod
    def create(cls, oid, name, len, type, category, delim, relid):
        """Create a type code for a PostgreSQL data type.

        The new string equals name; the other arguments are stored as
        attributes of the same name.
        """
        self = cls(name)
        self.oid, self.len, self.type = oid, len, type
        self.category, self.delim, self.relid = category, delim, relid
        return self

# name and type code of a field of a composite type
FieldInfo = namedtuple('FieldInfo', ('name', 'type'))
586
587
class TypeCache(dict):
    """Cache for database types.

    This cache maps type OIDs and names to TypeCode strings containing
    important information on the associated database type.  Types that
    are not cached yet are looked up in the pg_type catalog on demand.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()
        typecasts = LocalTypecasts()
        typecasts.get_fields = self.get_fields
        typecasts.connection = cnx
        self._typecasts = typecasts

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The key may be a type OID (int) or a type name.  A name without
        a dot or double quote is put in double quotes before it is cast
        to regtype.  The result is cached under both OID and name.

        Raises KeyError if the type cannot be found.
        """
        if isinstance(key, int):
            oid = key
        else:
            if '.' not in key and '"' not in key:
                key = '"%s"' % key
            oid = "'%s'::regtype" % self._escape_string(key)
        query = ("SELECT oid, typname,"
                 " typlen, typtype, typcategory, typdelim, typrelid"
                 " FROM pg_type WHERE oid=%s" % oid)
        try:
            self._src.execute(query)
        except ProgrammingError:
            rows = None
        else:
            rows = self._src.fetch(1)
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        row = rows[0]
        type_code = TypeCode.create(int(row[0]), row[1],
                                    int(row[2]), row[3], row[4], row[5],
                                    int(row[6]))
        self[type_code.oid] = self[str(type_code)] = type_code
        return type_code

    def get(self, key, default=None):
        """Get the type even if it is not cached; default if unknown."""
        try:
            return self[key]
        except KeyError:
            return default

    def get_fields(self, typ):
        """Get the names and types of the fields of composite types.

        Returns None if the type is unknown or not composite, otherwise
        a list of FieldInfo tuples in attribute order.
        """
        if not isinstance(typ, TypeCode):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None  # this type is not composite
        self._src.execute("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
        lookup = self.get
        return [FieldInfo(name, lookup(int(oid)))
                for name, oid in self._src.fetch(-1)]

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        casts = self._typecasts
        return casts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        casts = self._typecasts
        casts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        casts = self._typecasts
        casts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type."""
        if value is None:
            return None  # NULL values need no typecast
        cast = self.get_typecast(typ)
        if cast and cast is not str:
            return cast(value)
        return value  # no typecast is necessary
671
672
673class _quotedict(dict):
674    """Dictionary with auto quoting of its items.
675
676    The quote attribute must be set to the desired quote function.
677    """
678
679    def __getitem__(self, key):
680        return self.quote(super(_quotedict, self).__getitem__(key))
681
682
683### Error messages
684
def _db_error(msg, cls=DatabaseError):
    """Return an exception of class cls (DatabaseError by default).

    The sqlstate attribute of the exception is set to None.
    """
    error = cls(msg)
    error.sqlstate = None
    return error


def _op_error(msg):
    """Return an OperationalError whose sqlstate attribute is None."""
    return _db_error(msg, cls=OperationalError)
695
696
697### Cursor Object
698
699class Cursor(object):
700    """Cursor object."""
701
702    def __init__(self, dbcnx):
703        """Create a cursor object for the database connection."""
704        self.connection = self._dbcnx = dbcnx
705        self._cnx = dbcnx._cnx
706        self.type_cache = dbcnx.type_cache
707        self._src = self._cnx.source()
708        # the official attribute for describing the result columns
709        self._description = None
710        if self.row_factory is Cursor.row_factory:
711            # the row factory needs to be determined dynamically
712            self.row_factory = None
713        else:
714            self.build_row_factory = None
715        self.rowcount = -1
716        self.arraysize = 1
717        self.lastrowid = None
718
719    def __iter__(self):
720        """Make cursor compatible to the iteration protocol."""
721        return self
722
723    def __enter__(self):
724        """Enter the runtime context for the cursor object."""
725        return self
726
727    def __exit__(self, et, ev, tb):
728        """Exit the runtime context for the cursor object."""
729        self.close()
730
731    def _quote(self, value):
732        """Quote value depending on its type."""
733        if value is None:
734            return 'NULL'
735        if isinstance(value, (Hstore, Json)):
736            value = str(value)
737        if isinstance(value, basestring):
738            if isinstance(value, Binary):
739                value = self._cnx.escape_bytea(value)
740                if bytes is not str:  # Python >= 3.0
741                    value = value.decode('ascii')
742            else:
743                value = self._cnx.escape_string(value)
744            return "'%s'" % value
745        if isinstance(value, float):
746            if isinf(value):
747                return "'-Infinity'" if value < 0 else "'Infinity'"
748            if isnan(value):
749                return "'NaN'"
750            return value
751        if isinstance(value, (int, long, Decimal, Literal)):
752            return value
753        if isinstance(value, datetime):
754            if value.tzinfo:
755                return "'%s'::timestamptz" % value
756            return "'%s'::timestamp" % value
757        if isinstance(value, date):
758            return "'%s'::date" % value
759        if isinstance(value, time):
760            if value.tzinfo:
761                return "'%s'::timetz" % value
762            return "'%s'::time" % value
763        if isinstance(value, timedelta):
764            return "'%s'::interval" % value
765        if isinstance(value, Uuid):
766            return "'%s'::uuid" % value
767        if isinstance(value, list):
768            # Quote value as an ARRAY constructor. This is better than using
769            # an array literal because it carries the information that this is
770            # an array and not a string.  One issue with this syntax is that
771            # you need to add an explicit typecast when passing empty arrays.
772            # The ARRAY keyword is actually only necessary at the top level.
773            if not value:  # exception for empty array
774                return "'{}'"
775            q = self._quote
776            try:
777                return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
778            except UnicodeEncodeError:  # Python 2 with non-ascii values
779                return u'ARRAY[%s]' % ','.join(unicode(q(v)) for v in value)
780        if isinstance(value, tuple):
781            # Quote as a ROW constructor.  This is better than using a record
782            # literal because it carries the information that this is a record
783            # and not a string.  We don't use the keyword ROW in order to make
784            # this usable with the IN syntax as well.  It is only necessary
785            # when the records has a single column which is not really useful.
786            q = self._quote
787            try:
788                return '(%s)' % ','.join(str(q(v)) for v in value)
789            except UnicodeEncodeError:  # Python 2 with non-ascii values
790                return u'(%s)' % ','.join(unicode(q(v)) for v in value)
791        try:
792            value = value.__pg_repr__()
793        except AttributeError:
794            raise InterfaceError(
795                'Do not know how to adapt type %s' % type(value))
796        if isinstance(value, (tuple, list)):
797            value = self._quote(value)
798        return value
799
800    def _quoteparams(self, string, parameters):
801        """Quote parameters.
802
803        This function works for both mappings and sequences.
804
805        The function should be used even when there are no parameters,
806        so that we have a consistent behavior regarding percent signs.
807        """
808        if not parameters:
809            try:
810                return string % ()  # unescape literal quotes if possible
811            except (TypeError, ValueError):
812                return string  # silently accept unescaped quotes
813        if isinstance(parameters, dict):
814            parameters = _quotedict(parameters)
815            parameters.quote = self._quote
816        else:
817            parameters = tuple(map(self._quote, parameters))
818        return string % parameters
819
820    def _make_description(self, info):
821        """Make the description tuple for the given field info."""
822        name, typ, size, mod = info[1:]
823        type_code = self.type_cache[typ]
824        if mod > 0:
825            mod -= 4
826        if type_code == 'numeric':
827            precision, scale = mod >> 16, mod & 0xffff
828            size = precision
829        else:
830            if not size:
831                size = type_code.size
832            if size == -1:
833                size = mod
834            precision = scale = None
835        return CursorDescription(name, type_code,
836            None, size, precision, scale, None)
837
838    @property
839    def description(self):
840        """Read-only attribute describing the result columns."""
841        descr = self._description
842        if self._description is True:
843            make = self._make_description
844            descr = [make(info) for info in self._src.listinfo()]
845            self._description = descr
846        return descr
847
848    @property
849    def colnames(self):
850        """Unofficial convenience method for getting the column names."""
851        return [d[0] for d in self.description]
852
853    @property
854    def coltypes(self):
855        """Unofficial convenience method for getting the column types."""
856        return [d[1] for d in self.description]
857
858    def close(self):
859        """Close the cursor object."""
860        self._src.close()
861
862    def execute(self, operation, parameters=None):
863        """Prepare and execute a database operation (query or command)."""
864        # The parameters may also be specified as list of tuples to e.g.
865        # insert multiple rows in a single operation, but this kind of
866        # usage is deprecated.  We make several plausibility checks because
867        # tuples can also be passed with the meaning of ROW constructors.
868        if (parameters and isinstance(parameters, list)
869                and len(parameters) > 1
870                and all(isinstance(p, tuple) for p in parameters)
871                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
872            return self.executemany(operation, parameters)
873        else:
874            # not a list of tuples
875            return self.executemany(operation, [parameters])
876
    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        Each parameter set in seq_of_parameters is merged into operation
        with _quoteparams() and the resulting SQL is executed in turn.
        If the connection is not yet in a transaction, a BEGIN is sent
        first and the connection is marked as being in a transaction.

        If the last statement returned rows (RESULT_DQL), the description
        is marked to be built on demand, rowcount is set to the number of
        result tuples and the row factory is rebuilt.  Otherwise rowcount
        is the sum of the counts returned by the executed statements and
        lastrowid is taken from oidstatus().

        Returns the cursor itself, or None (doing nothing) if
        seq_of_parameters is empty.

        DatabaseError instances are re-raised unchanged.  Other pgdb Error
        instances are re-raised through _db_error(..., InterfaceError) and
        any other exception through _op_error(); both messages name the
        SQL statement that was being processed.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self._description = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        # sql always holds the statement currently being processed, so that
        # the error messages below can mention it
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                # start a transaction through a fresh source object of the
                # underlying connection, not through self._src
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("Can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                # if quoting fails, the error message shows the raw operation
                sql = operation
                sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                # rows is added up as the row count of the commands;
                # NOTE(review): presumably a positive count for commands
                # that affected rows and falsy otherwise (the former note
                # "true if not DML" looked inverted) -- confirm
                if rows:
                    rowcount += rows
                else:
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "Error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("Internal error in '%s': %s" % (sql, err))
        # then initialize result row count and description
        if self._src.resulttype == RESULT_DQL:
            self._description = True  # fetch on demand
            self.rowcount = self._src.ntuples
            self.lastrowid = None
            if self.build_row_factory:
                # the row factory may depend on the new column description
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self
924
925    def fetchone(self):
926        """Fetch the next row of a query result set."""
927        res = self.fetchmany(1, False)
928        try:
929            return res[0]
930        except IndexError:
931            return None
932
933    def fetchall(self):
934        """Fetch all (remaining) rows of a query result."""
935        return self.fetchmany(-1, False)
936
937    def fetchmany(self, size=None, keep=False):
938        """Fetch the next set of rows of a query result.
939
940        The number of rows to fetch per call is specified by the
941        size parameter. If it is not given, the cursor's arraysize
942        determines the number of rows to be fetched. If you set
943        the keep parameter to true, this is kept as new arraysize.
944        """
945        if size is None:
946            size = self.arraysize
947        if keep:
948            self.arraysize = size
949        try:
950            result = self._src.fetch(size)
951        except DatabaseError:
952            raise
953        except Error as err:
954            raise _db_error(str(err))
955        typecast = self.type_cache.typecast
956        return [self.row_factory([typecast(value, typ)
957            for typ, value in zip(self.coltypes, row)]) for row in result]
958
959    def callproc(self, procname, parameters=None):
960        """Call a stored database procedure with the given name.
961
962        The sequence of parameters must contain one entry for each input
963        argument that the procedure expects. The result of the call is the
964        same as this input sequence; replacement of output and input/output
965        parameters in the return value is currently not supported.
966
967        The procedure may also provide a result set as output. These can be
968        requested through the standard fetch methods of the cursor.
969        """
970        n = parameters and len(parameters) or 0
971        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
972        self.execute(query, parameters)
973        return parameters
974
975    def copy_from(self, stream, table,
976            format=None, sep=None, null=None, size=None, columns=None):
977        """Copy data from an input stream to the specified table.
978
979        The input stream can be a file-like object with a read() method or
980        it can also be an iterable returning a row or multiple rows of input
981        on each iteration.
982
983        The format must be text, csv or binary. The sep option sets the
984        column separator (delimiter) used in the non binary formats.
985        The null option sets the textual representation of NULL in the input.
986
987        The size option sets the size of the buffer used when reading data
988        from file-like objects.
989
990        The copy operation can be restricted to a subset of columns. If no
991        columns are specified, all of them will be copied.
992        """
993        binary_format = format == 'binary'
994        try:
995            read = stream.read
996        except AttributeError:
997            if size:
998                raise ValueError("Size must only be set for file-like objects")
999            if binary_format:
1000                input_type = bytes
1001                type_name = 'byte strings'
1002            else:
1003                input_type = basestring
1004                type_name = 'strings'
1005
1006            if isinstance(stream, basestring):
1007                if not isinstance(stream, input_type):
1008                    raise ValueError("The input must be %s" % type_name)
1009                if not binary_format:
1010                    if isinstance(stream, str):
1011                        if not stream.endswith('\n'):
1012                            stream += '\n'
1013                    else:
1014                        if not stream.endswith(b'\n'):
1015                            stream += b'\n'
1016
1017                def chunks():
1018                    yield stream
1019
1020            elif isinstance(stream, Iterable):
1021
1022                def chunks():
1023                    for chunk in stream:
1024                        if not isinstance(chunk, input_type):
1025                            raise ValueError(
1026                                "Input stream must consist of %s" % type_name)
1027                        if isinstance(chunk, str):
1028                            if not chunk.endswith('\n'):
1029                                chunk += '\n'
1030                        else:
1031                            if not chunk.endswith(b'\n'):
1032                                chunk += b'\n'
1033                        yield chunk
1034
1035            else:
1036                raise TypeError("Need an input stream to copy from")
1037        else:
1038            if size is None:
1039                size = 8192
1040            elif not isinstance(size, int):
1041                raise TypeError("The size option must be an integer")
1042            if size > 0:
1043
1044                def chunks():
1045                    while True:
1046                        buffer = read(size)
1047                        yield buffer
1048                        if not buffer or len(buffer) < size:
1049                            break
1050
1051            else:
1052
1053                def chunks():
1054                    yield read()
1055
1056        if not table or not isinstance(table, basestring):
1057            raise TypeError("Need a table to copy to")
1058        if table.lower().startswith('select'):
1059                raise ValueError("Must specify a table, not a query")
1060        else:
1061            table = '"%s"' % (table,)
1062        operation = ['copy %s' % (table,)]
1063        options = []
1064        params = []
1065        if format is not None:
1066            if not isinstance(format, basestring):
1067                raise TypeError("The frmat option must be be a string")
1068            if format not in ('text', 'csv', 'binary'):
1069                raise ValueError("Invalid format")
1070            options.append('format %s' % (format,))
1071        if sep is not None:
1072            if not isinstance(sep, basestring):
1073                raise TypeError("The sep option must be a string")
1074            if format == 'binary':
1075                raise ValueError(
1076                    "The sep option is not allowed with binary format")
1077            if len(sep) != 1:
1078                raise ValueError(
1079                    "The sep option must be a single one-byte character")
1080            options.append('delimiter %s')
1081            params.append(sep)
1082        if null is not None:
1083            if not isinstance(null, basestring):
1084                raise TypeError("The null option must be a string")
1085            options.append('null %s')
1086            params.append(null)
1087        if columns:
1088            if not isinstance(columns, basestring):
1089                columns = ','.join('"%s"' % (col,) for col in columns)
1090            operation.append('(%s)' % (columns,))
1091        operation.append("from stdin")
1092        if options:
1093            operation.append('(%s)' % ','.join(options))
1094        operation = ' '.join(operation)
1095
1096        putdata = self._src.putdata
1097        self.execute(operation, params)
1098
1099        try:
1100            for chunk in chunks():
1101                putdata(chunk)
1102        except BaseException as error:
1103            self.rowcount = -1
1104            # the following call will re-raise the error
1105            putdata(error)
1106        else:
1107            self.rowcount = putdata(None)
1108
1109        # return the cursor object, so you can chain operations
1110        return self
1111
1112    def copy_to(self, stream, table,
1113            format=None, sep=None, null=None, decode=None, columns=None):
1114        """Copy data from the specified table to an output stream.
1115
1116        The output stream can be a file-like object with a write() method or
1117        it can also be None, in which case the method will return a generator
1118        yielding a row on each iteration.
1119
1120        Output will be returned as byte strings unless you set decode to true.
1121
1122        Note that you can also use a select query instead of the table name.
1123
1124        The format must be text, csv or binary. The sep option sets the
1125        column separator (delimiter) used in the non binary formats.
1126        The null option sets the textual representation of NULL in the output.
1127
1128        The copy operation can be restricted to a subset of columns. If no
1129        columns are specified, all of them will be copied.
1130        """
1131        binary_format = format == 'binary'
1132        if stream is not None:
1133            try:
1134                write = stream.write
1135            except AttributeError:
1136                raise TypeError("Need an output stream to copy to")
1137        if not table or not isinstance(table, basestring):
1138            raise TypeError("Need a table to copy to")
1139        if table.lower().startswith('select'):
1140            if columns:
1141                raise ValueError("Columns must be specified in the query")
1142            table = '(%s)' % (table,)
1143        else:
1144            table = '"%s"' % (table,)
1145        operation = ['copy %s' % (table,)]
1146        options = []
1147        params = []
1148        if format is not None:
1149            if not isinstance(format, basestring):
1150                raise TypeError("The format option must be a string")
1151            if format not in ('text', 'csv', 'binary'):
1152                raise ValueError("Invalid format")
1153            options.append('format %s' % (format,))
1154        if sep is not None:
1155            if not isinstance(sep, basestring):
1156                raise TypeError("The sep option must be a string")
1157            if binary_format:
1158                raise ValueError(
1159                    "The sep option is not allowed with binary format")
1160            if len(sep) != 1:
1161                raise ValueError(
1162                    "The sep option must be a single one-byte character")
1163            options.append('delimiter %s')
1164            params.append(sep)
1165        if null is not None:
1166            if not isinstance(null, basestring):
1167                raise TypeError("The null option must be a string")
1168            options.append('null %s')
1169            params.append(null)
1170        if decode is None:
1171            if format == 'binary':
1172                decode = False
1173            else:
1174                decode = str is unicode
1175        else:
1176            if not isinstance(decode, (int, bool)):
1177                raise TypeError("The decode option must be a boolean")
1178            if decode and binary_format:
1179                raise ValueError(
1180                    "The decode option is not allowed with binary format")
1181        if columns:
1182            if not isinstance(columns, basestring):
1183                columns = ','.join('"%s"' % (col,) for col in columns)
1184            operation.append('(%s)' % (columns,))
1185
1186        operation.append("to stdout")
1187        if options:
1188            operation.append('(%s)' % ','.join(options))
1189        operation = ' '.join(operation)
1190
1191        getdata = self._src.getdata
1192        self.execute(operation, params)
1193
1194        def copy():
1195            self.rowcount = 0
1196            while True:
1197                row = getdata(decode)
1198                if isinstance(row, int):
1199                    if self.rowcount != row:
1200                        self.rowcount = row
1201                    break
1202                self.rowcount += 1
1203                yield row
1204
1205        if stream is None:
1206            # no input stream, return the generator
1207            return copy()
1208
1209        # write the rows to the file-like input stream
1210        for row in copy():
1211            write(row)
1212
1213        # return the cursor object, so you can chain operations
1214        return self
1215
1216    def __next__(self):
1217        """Return the next row (support for the iteration protocol)."""
1218        res = self.fetchone()
1219        if res is None:
1220            raise StopIteration
1221        return res
1222
1223    # Note that since Python 2.6 the iterator protocol uses __next()__
1224    # instead of next(), we keep it only for backward compatibility of pgdb.
1225    next = __next__
1226
1227    @staticmethod
1228    def nextset():
1229        """Not supported."""
1230        raise NotSupportedError("The nextset() method is not supported")
1231
1232    @staticmethod
1233    def setinputsizes(sizes):
1234        """Not supported."""
1235        pass  # unsupported, but silently passed
1236
1237    @staticmethod
1238    def setoutputsize(size, column=0):
1239        """Not supported."""
1240        pass  # unsupported, but silently passed
1241
1242    @staticmethod
1243    def row_factory(row):
1244        """Process rows before they are returned.
1245
1246        You can overwrite this statically with a custom row factory, or
1247        you can build a row factory dynamically with build_row_factory().
1248
1249        For example, you can create a Cursor class that returns rows as
1250        Python dictionaries like this:
1251
1252            class DictCursor(pgdb.Cursor):
1253
1254                def row_factory(self, row):
1255                    return {desc[0]: value
1256                        for desc, value in zip(self.description, row)}
1257
1258            cur = DictCursor(con)  # get one DictCursor instance or
1259            con.cursor_type = DictCursor  # always use DictCursor instances
1260        """
1261        raise NotImplementedError
1262
1263    def build_row_factory(self):
1264        """Build a row factory based on the current description.
1265
1266        This implementation builds a row factory for creating named tuples.
1267        You can overwrite this method if you want to dynamically create
1268        different row factories whenever the column description changes.
1269        """
1270        colnames = self.colnames
1271        if colnames:
1272            try:
1273                try:
1274                    return namedtuple('Row', colnames, rename=True)._make
1275                except TypeError:  # Python 2.6 and 3.0 do not support rename
1276                    colnames = [v if v.isalnum() else 'column_%d' % n
1277                             for n, v in enumerate(colnames)]
1278                    return namedtuple('Row', colnames)._make
1279            except ValueError:  # there is still a problem with the field names
1280                colnames = ['column_%d' % n for n in range(len(colnames))]
1281                return namedtuple('Row', colnames)._make
1282
1283
# The seven items describing a result column, as required by the DB-API 2
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
1287
1288
1289### Connection Objects
1290
class Connection(object):
    """Connection object.

    Wraps an underlying connection and keeps track of whether a
    transaction has been started on it, so that commit() and rollback()
    only send COMMIT or ROLLBACK when there is something to end.
    """

    # the DB-API exception classes, also reachable through the connection
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object wrapping cnx."""
        self._cnx = cnx  # the underlying connection
        self._tnx = False  # whether a transaction has been started
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor
        try:
            cnx.source()
        except Exception:
            raise _op_error("Invalid connection")

    def __enter__(self):
        """Enter the runtime context, which can be used for transactions."""
        return self

    def __exit__(self, et, ev, tb):
        """Leave the runtime context by ending the current transaction.

        The transaction is rolled back if the context was left with an
        exception and committed otherwise; the connection stays open.
        """
        if et is not None or ev is not None or tb is not None:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection, rolling back a pending transaction first."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # the connection will be closed anyway
        self._cnx.close()
        self._cnx = None

    @property
    def closed(self):
        """Check whether the connection has been closed or is broken."""
        try:
            if not self._cnx:
                return True
            return self._cnx.status != 1
        except TypeError:
            return True

    def _end_transaction(self, command, failure):
        """Send command to end the pending transaction, if there is one.

        The transaction flag is cleared before the command is sent.
        DatabaseError is re-raised as is; any other exception is replaced
        by _op_error(failure).
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            self._tnx = False
            try:
                self._cnx.source().execute(command)
            except DatabaseError:
                raise
            except Exception:
                raise _op_error(failure)

    def commit(self):
        """Commit any pending transaction to the database."""
        self._end_transaction("COMMIT", "Can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        self._end_transaction("ROLLBACK", "Can't rollback")

    def cursor(self):
        """Return a new cursor object of cursor_type using the connection."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("Invalid connection")

    if shortcutmethods:  # the shortcut methods exist only when enabled

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cur = self.cursor()
            cur.execute(operation, params)
            return cur

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cur = self.cursor()
            cur.executemany(operation, param_seq)
            return cur
1406
1407
1408### Module Interface
1409
# Keep a private reference to the connect() function that is in scope at
# this point, because the DB-API connect() defined below rebinds the name
# and delegates to it as _connect().  NOTE(review): presumably this is the
# low-level connect() imported earlier in the file -- confirm.
_connect = connect
1411
def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database and return a Connection object.

    The optional dsn has the form 'host:database:user:password:opt',
    where missing trailing parts are left empty.  The keyword arguments
    user, password and database override the corresponding dsn parts;
    host overrides the dsn host and may be given as 'host:port'.
    An empty host or user is passed on as None.
    """
    dbport = -1
    # take the parts of the DSN, padded with empty strings
    try:
        dbhost, dbbase, dbuser, dbpasswd, dbopt = (
            dsn.split(":") + [""] * 5)[:5]
    except (AttributeError, IndexError, TypeError):
        dbhost = dbbase = dbuser = dbpasswd = dbopt = ""

    # explicit keyword arguments take precedence
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # no (valid) port given

    # an empty host means localhost
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    return Connection(
        _connect(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd))
1457
1458
1459### Types Handling
1460
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type can be created from a whitespace separated string of type
    names and compares equal to each type name it contains.  A leading
    underscore (the prefix of array type names) is ignored, so an array
    type name also matches the Type of its elements.
    """

    def __new__(cls, values):
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Defining __eq__ implicitly sets __hash__ to None in Python 3, which
    # made type objects unhashable there (but not in Python 2); use the
    # hash of the underlying frozenset of type names again.
    __hash__ = frozenset.__hash__
1488
1489
class ArrayType:
    """Type class for PostgreSQL array types.

    Compares equal to every type name starting with an underscore (the
    prefix of array type names) and to other ArrayType instances.
    """

    def __eq__(self, other):
        if isinstance(other, basestring):
            return other.startswith('_')
        else:
            return isinstance(other, ArrayType)

    def __ne__(self, other):
        if isinstance(other, basestring):
            return not other.startswith('_')
        else:
            return not isinstance(other, ArrayType)

    def __hash__(self):
        # All ArrayType instances compare equal, so they must share one
        # hash value; without this, defining __eq__ makes the class
        # unhashable in Python 3.
        return hash(ArrayType)
1504
1505
class RecordType:
    """Type class for PostgreSQL record types.

    Compares equal to type codes whose type attribute is 'c' (composite
    types), to the type name 'record' and to other RecordType instances.
    """

    def __eq__(self, other):
        if isinstance(other, TypeCode):
            return other.type == 'c'
        elif isinstance(other, basestring):
            return other == 'record'
        else:
            return isinstance(other, RecordType)

    def __ne__(self, other):
        if isinstance(other, TypeCode):
            return other.type != 'c'
        elif isinstance(other, basestring):
            return other != 'record'
        else:
            return not isinstance(other, RecordType)

    def __hash__(self):
        # All RecordType instances compare equal, so they must share one
        # hash value; without this, defining __eq__ makes the class
        # unhashable in Python 3.
        return hash(RecordType)
1524
1525
# Type objects required by the DB-API 2 specification; each of them
# compares equal to the names of the PostgreSQL types it stands for:

STRING = Type('char bpchar name text varchar')
BINARY = Type('bytea')
NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
DATETIME = Type(
    'date time timetz timestamp timestamptz interval'
    ' abstime reltime')  # abstime and reltime are very old types
ROWID = Type('oid')


# More specific type objects provided in addition by PyGreSQL:

BOOL = Type('bool')
SMALLINT = Type('int2')
INTEGER = Type('int2 int4 int8 serial')
LONG = Type('int8')
FLOAT = Type('float4 float8')
NUMERIC = Type('numeric')
MONEY = Type('money')
DATE = Type('date')
TIME = Type('time timetz')
TIMESTAMP = Type('timestamp timestamptz')
INTERVAL = Type('interval')
UUID = Type('uuid')
HSTORE = Type('hstore')
JSON = Type('json jsonb')

# Type object matching all array types; since Type objects ignore the
# leading underscore of array type names, arrays also match the type
# objects of their element types:

ARRAY = ArrayType()

# Type object matching records, i.e. all composite types:

RECORD = RecordType()
1560
1561
1562# Mandatory type helpers defined by DB-API 2 specs:
1563
def Date(year, month, day):
    """Construct an object holding a date value."""
    return date(year, month, day)


def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Construct an object holding a time value (parts default to zero)."""
    return time(hour, minute=minute, second=second,
                microsecond=microsecond, tzinfo=tzinfo)


def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Construct an object holding a time stamp value."""
    return datetime(year, month, day, hour=hour, minute=minute,
                    second=second, microsecond=microsecond, tzinfo=tzinfo)


def DateFromTicks(ticks):
    """Construct a date value from the given ticks, in local time."""
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)


def TimeFromTicks(ticks):
    """Construct a time value from the given ticks, in local time."""
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)


def TimestampFromTicks(ticks):
    """Construct a time stamp from the given ticks, in local time."""
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
1593
1594
class Binary(bytes):
    """Byte string subclass marking a binary (long) string value."""
1597
1598
1599# Additional type helpers for PyGreSQL:
1600
def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
    """Construct an object holding a time interval value (a timedelta)."""
    return timedelta(days=days, hours=hours, minutes=minutes,
                     seconds=seconds, microseconds=microseconds)
1605
1606
# This self-assignment changes nothing at runtime; it only lists the Uuid
# name that is already in scope among the type helpers of this module.
# NOTE(review): presumably Uuid is uuid.UUID imported earlier -- confirm.
Uuid = Uuid  # Construct an object holding a UUID value
1608
1609
class Hstore(dict):
    """Wrapper class for marking hstore values.

    str() renders the dictionary in the input syntax of the PostgreSQL
    hstore type: comma separated key=>value pairs, where None becomes
    NULL and keys and values are quoted and escaped where necessary.
    """

    # Strings that need double quotes: the literal NULL (in any case),
    # which would otherwise be read as a null value, and strings with
    # whitespace of any kind, commas, equal signs or greater-than signs.
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>]')
    # double quotes and backslashes are escaped with a backslash
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s):
        """Return s formatted as a key or value in hstore syntax."""
        if s is None:
            return 'NULL'
        if not isinstance(s, basestring):
            # Convert non-strings such as numbers first; they could not be
            # searched by the regular expressions, and 0 must not end up
            # as an empty string.
            s = str(s)
        if not s:
            return '""'
        quote = cls._re_quote.search(s)
        s = cls._re_escape.sub(r'\\\1', s)
        if quote:
            s = '"%s"' % s
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
1631
1632
class Json:
    """Construct a wrapper for holding an object serializable to JSON."""

    def __init__(self, obj, encode=None):
        self.obj = obj
        # use the given encoder, falling back to the module's default one
        self.encode = encode if encode else jsonencode

    def __str__(self):
        value = self.obj
        # strings are assumed to be JSON already and are passed as they are
        return value if isinstance(value, basestring) else self.encode(value)
1645
1646
class Literal:
    """Construct a wrapper for a string to be inserted as literal SQL."""

    def __init__(self, sql):
        self.sql = sql

    def __str__(self):
        return self.sql

    # the SQL is rendered as it is, without any quoting
    __pg_repr__ = __str__
1657
# If run as script, print some information:
# the PyGreSQL version, an empty line and the module docstring.

if __name__ == '__main__':
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.