source: trunk/pgdb.py @ 821

Last change on this file since 821 was 821, checked in by cito, 3 years ago

Support the uuid data type

This is often useful and also supported by SQLAlchemy

  • Property svn:keywords set to Id
File size: 52.8 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 821 2016-02-05 16:13:54Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
    # All parts are optional. You may also pass host, database, user
    # and password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence) if
    # they are passed. The binding syntax is the same as the % operator
    # (paramstyle 'pyformat'), and the parameter values are quoted.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
    # Note that display_size and null_ok are not implemented, and
    # precision and scale are only provided for numeric columns.
53
    cursor.rowcount # number of rows in the result set, or number of rows
    # affected by a DML statement. Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from uuid import UUID
74from math import isnan, isinf
75from collections import namedtuple
76from functools import partial
77from re import compile as regex
78from json import loads as jsondecode, dumps as jsonencode
79
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# The ABC aliases in the collections module itself were deprecated in
# Python 3.3 and removed in Python 3.10, so prefer collections.abc.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
96
97
### Module Constants

# the DB-API specification level this module complies with
apilevel = '2.0'

# threads may share the module, but they must not share connections
threadsafety = 1

# parameters are bound with Python's extended format codes, e.g. %(name)s
paramstyle = 'pyformat'

# the shortcut methods are not part of DB-API 2 and are not recommended
# by the DB-SIG, but they are handy and therefore provided anyway
shortcutmethods = 1
112
113
114### Internal Type Handling
115
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of func (via getargspec)."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the list of parameter names of func (via signature)."""
        params = signature(func).parameters
        return [name for name in params]
127
try:
    # strptime() only understands the %z directive since Python 3.2
    if datetime.strptime('+0100', '%z') is None:
        raise ValueError
except ValueError:  # Python < 3.2
    timezones = None
else:
    # UTC offsets of the time zone abbreviations that may appear in
    # Postgres timestamptz output
    timezones = {
        'CET': '+0100', 'EET': '+0200', 'EST': '-0500',
        'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700',
        'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000'}
138
139
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal values.

    When decimal_type is given, it replaces the module-level Decimal type
    and is registered as the global typecast for 'numeric'.  The type
    currently in effect is returned in either case.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    Decimal = decimal_type
    set_typecast('numeric', decimal_type)
    return Decimal
151
152
def cast_bool(value):
    """Cast a boolean in database format ('t'/'f') to bool.

    Returns None for an empty value.
    """
    if not value:
        return None
    first = value[0]
    return first in ('t', 'T')
157
158
def cast_money(value):
    """Cast a money value in database format to Decimal.

    An opening parenthesis (accounting notation for negative amounts) is
    treated as a minus sign; all characters except digits, '.' and '-'
    (currency symbols, group separators) are dropped.  Returns None for an
    empty value.
    """
    if not value:
        return None
    signed = value.replace('(', '-')
    kept = [ch for ch in signed if ch.isdigit() or ch in '.-']
    return Decimal(''.join(kept))
164
165
def cast_int2vector(value):
    """Cast an int2vector value (space separated integers) to a list."""
    return list(map(int, value.split()))
169
170
def cast_date(value, connection):
    """Cast a date value.

    The output format depends on the server setting DateStyle.  The default
    setting ISO and the setting for German are actually unambiguous.  The
    order of days and months in the other two settings is however ambiguous,
    so the date format is taken from connection.date_format().

    Infinite dates and BC dates are mapped to date.min/date.max, and dates
    whose date part is longer than 10 characters (years beyond 9999) are
    mapped to date.max, since Python's date cannot represent them.
    """
    if value in ('-infinity', 'infinity'):
        return date.min if value.startswith('-') else date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    day = parts[0]
    if len(day) > 10:
        return date.max
    return datetime.strptime(day, connection.date_format()).date()
189
190
def cast_time(value):
    """Cast a time value, with or without fractional seconds."""
    fmt = '%H:%M:%S'
    if len(value) > 8:  # longer than HH:MM:SS, so there is a fraction
        fmt += '.%f'
    return datetime.strptime(value, fmt).time()
195
196
197_re_timezone = regex('(.*)([+-].*)')
198
199
def cast_timetz(value):
    """Cast a timetz value.

    The UTC offset is split off the time.  When strptime() supports %z,
    it is normalized to the +HHMM form (known zone abbreviations are
    translated via the timezones dict, unknown ones become +0000) and
    attached to the result.  Otherwise a naive time is returned.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
    if timezones:
        if tz.startswith(('+', '-')):
            # '+01' becomes '+0100', '+05:30' becomes '+0530'
            tz = tz + '00' if len(tz) < 5 else tz.replace(':', '')
        else:
            tz = timezones.get(tz, '+0000')
        value += tz
        fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
221
222
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    The date format is taken from connection.date_format().  For the
    Postgres DateStyle (format ending in '-%Y'), the output looks like
    'Wdy Mon DD HH:MM:SS[.f] YYYY' and is parsed accordingly.  Infinite
    and BC timestamps are mapped to datetime.min/datetime.max, and years
    beyond 9999 to datetime.max.
    """
    if value in ('-infinity', 'infinity'):
        return datetime.min if value.startswith('-') else datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # drop the weekday, keep month/day, time and year
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        clock_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_fmt, clock_fmt, '%Y']
    else:
        if len(parts[0]) > 10:
            return datetime.max
        clock_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, clock_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(formats))
244
245
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    The date format is taken from connection.date_format().  Depending on
    the DateStyle, the zone is either glued to the time (ISO, e.g.
    '16:13:54+01') or given as the last token (other styles).  When
    strptime() supports %z, the zone is normalized to +HHMM (known
    abbreviations via the timezones dict, unknown ones become +0000) and
    an aware datetime is returned; otherwise the result is naive.
    Infinite and BC values map to datetime.min/datetime.max, and years
    beyond 9999 to datetime.max.
    """
    if value in ('-infinity', 'infinity'):
        return datetime.min if value.startswith('-') else datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # Postgres style: weekday, month/day, time, year, zone
        parts = parts[1:]
        if len(parts[3]) > 4:
            return datetime.max
        formats = [
            '%d %b' if date_fmt.startswith('%d') else '%b %d',
            '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S',
            '%Y']
        tz = parts.pop()
    else:
        if date_fmt.startswith('%Y-'):
            # ISO style: the offset is attached to the time
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        else:
            # SQL or German style: the zone is the last token
            tz = parts.pop()
        if len(parts[0]) > 10:
            return datetime.max
        formats = [date_fmt,
                   '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S']
    if timezones:
        if tz.startswith(('+', '-')):
            # '+01' becomes '+0100', '+05:30' becomes '+0530'
            tz = tz + '00' if len(tz) < 5 else tz.replace(':', '')
        else:
            tz = timezones.get(tz, '+0000')
        parts.append(tz)
        formats.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(formats))
288
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_usecs(fraction):
    """Convert the digits after the decimal point of seconds to microseconds.

    The digits are a decimal fraction, not a microsecond count: PostgreSQL
    omits trailing zeros, so '5' means 500000 and '123456' means 123456
    microseconds.  Digits beyond the sixth are truncated.
    """
    return int(fraction[:6].ljust(6, '0'))


def cast_interval(value):
    """Cast an interval value to a timedelta.

    The output format depends on the server setting IntervalStyle, but it's
    not necessary to consult this setting to parse it.  It's faster to just
    check all possible formats, and there is no ambiguity here.

    Since timedelta has no notion of calendar units, a year is counted as
    365 days and a month as 30 days.

    Raises ValueError if the value matches none of the known formats.
    """
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        secs_ago = m.pop(5) == '-'
        usecs = _interval_usecs(m.pop())
        years, mons, days, hours, mins, secs = [int(d) for d in m]
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            secs_ago = m.pop(5) == '-'
            usecs = _interval_usecs(m.pop())
            years, mons, days, hours, mins, secs = [int(d) for d in m]
            if ago:
                # 'ago' negates the whole interval
                years, mons, days = -years, -mons, -days
                hours, mins, secs, usecs = -hours, -mins, -secs, -usecs
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                hours_ago = m.pop(3) == '-'
                usecs = _interval_usecs(m.pop())
                years, mons, days, hours, mins, secs = [int(d) for d in m]
                if hours_ago:
                    # the sign in front of the time applies to all its parts
                    hours, mins, secs, usecs = -hours, -mins, -secs, -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = [d or '0' for d in m.groups()]
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    usecs = _interval_usecs(m.pop())
                    years, mons, days, hours, mins, secs = [
                        int(d) for d in m]
                    if years_ago:
                        years, mons = -years, -mons
                    if hours_ago:
                        hours, mins, secs, usecs = (
                            -hours, -mins, -secs, -usecs)
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
374
375
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.

    Entries are created lazily by __missing__: defaults get the connection
    bound if they need it, and array casts ('_' prefix) are derived from
    the cast of their base type.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
        'float4': float, 'float8': float,
        'numeric': Decimal, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': UUID,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in local connection specific instances

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.

        Raises TypeError if typ is not a string.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # create array cast
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                # store only if base type exists
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument.

        True if 'connection' is one of the arguments after the first one.
        Functions whose arguments cannot be inspected are assumed not to
        need it.
        """
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type."""
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        Passing None as cast removes the cached cast.  Cached array casts
        of the affected types are dropped so that they get rebuilt from
        the new base cast.  Raises TypeError if cast is not callable.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.

        Cached casts are dropped rather than copied from the defaults, and
        __missing__ recreates them on the next lookup.  This way, casts
        that need a connection get it bound again, and array casts are
        rebuilt from the (possibly connection-bound) base cast of this
        instance instead of being copied from the defaults without it.
        """
        if typ is None:
            self.clear()
            return
        if isinstance(typ, basestring):
            typ = [typ]
        defaults = self.defaults
        for t in typ:
            cast = defaults.get(t)
            if cast:
                self[t] = self._add_connection(cast)
            else:
                self.pop(t, None)
            # the array cast will be recreated from the base cast
            self.pop('_%s' % t, None)

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts.

        The resulting cast returns instances of a namedtuple called name
        with the given field names.
        """
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
497
498
# the global typecast dictionary, used as defaults by all connections
_typecasts = Typecasts()


def get_typecast(typ):
    """Return the global typecast function for the given database type.

    None is returned if there is no special cast function for the type.
    """
    typecasts = _typecasts
    return typecasts.get(typ)
505
506
def set_typecast(typ, cast):
    """Register cast as global typecast function for the given type(s).

    typ may be a single type name or a sequence of names; passing None as
    cast removes the global typecast for these types.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    typecasts = _typecasts
    typecasts.set(typ, cast)
514
515
def reset_typecast(typ=None):
    """Restore the default global typecasts for the given type(s).

    typ may be a single type name or a sequence of names; without it,
    all global typecasts are reset.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    typecasts = _typecasts
    typecasts.reset(typ)
525
526
class LocalTypecasts(Typecasts):
    """Map typecasts, including local composite types, to cast functions.

    The global typecast dictionary serves as defaults.  Types which are
    neither arrays nor known there are looked up as composite types via
    get_fields(), and a record cast is created for them.
    """

    defaults = _typecasts

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Returns None (without caching) when there is no special cast
        function for the type.
        """
        if typ.startswith('_'):
            element_cast = self[typ[1:]]
            array_cast = self.create_array_cast(element_cast)
            if element_cast:
                # cache the array cast only if its element type is known
                self[typ] = array_cast
            return array_cast
        cast = self.defaults.get(typ)
        if cast:
            cast = self[typ] = self._add_connection(cast)
            return cast
        fields = self.get_fields(typ)
        if not fields:
            return cast
        field_casts = [self[field.type] for field in fields]
        field_names = [field.name for field in fields]
        cast = self[typ] = self.create_record_cast(
            typ, field_names, field_casts)
        return cast

    def get_fields(self, typ):
        """Return the fields for the given record type.

        This placeholder always returns an empty list.  It is meant to be
        replaced with a method that looks up the fields using the type
        cache of the connection.
        """
        return []
562
563
class TypeCode(str):
    """Class representing the type_code used by the DB-API 2.0.

    A TypeCode compares equal to the PostgreSQL type name, and carries the
    further pg_type information as attributes: oid, len, type, category,
    delim and relid.
    """

    @classmethod
    def create(cls, oid, name, len, type, category, delim, relid):
        """Create a type code for a PostgreSQL data type."""
        type_code = cls(name)
        type_code.oid, type_code.len, type_code.type = oid, len, type
        type_code.category = category
        type_code.delim, type_code.relid = delim, relid
        return type_code
582
# name and type code of a field in a composite type
FieldInfo = namedtuple('FieldInfo', 'name type')
584
585
class TypeCache(dict):
    """Cache for database types.

    Maps both type OIDs and type names to TypeCode strings containing
    important information on the associated database type.  Missing
    entries are fetched from the pg_type system catalog.  The cache also
    holds the connection specific typecast functions.
    """

    def __init__(self, cnx):
        """Initialize the type cache for the given connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()
        local_casts = LocalTypecasts()
        local_casts.get_fields = self.get_fields
        local_casts.connection = cnx
        self._typecasts = local_casts

    def __missing__(self, key):
        """Fetch the type info from the database if it is not cached.

        key is either an int (a type OID) or a type name.  Unqualified,
        unquoted names are double-quoted before being cast to regtype.
        The result is cached under both its OID and its name.
        Raises KeyError if the type cannot be found.
        """
        if isinstance(key, int):
            oid_expr = key
        else:
            if '.' not in key and '"' not in key:
                key = '"%s"' % key
            oid_expr = "'%s'::regtype" % self._escape_string(key)
        # NOTE(review): if this runs inside an open transaction, a failing
        # regtype cast might abort that transaction -- confirm
        try:
            self._src.execute("SELECT oid, typname,"
                " typlen, typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s" % oid_expr)
        except ProgrammingError:
            rows = None
        else:
            rows = self._src.fetch(1)
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        oid, name, length, kind, category, delim, relid = rows[0]
        type_code = TypeCode.create(int(oid), name, int(length),
                                    kind, category, delim, int(relid))
        self[type_code.oid] = self[str(type_code)] = type_code
        return type_code

    def get(self, key, default=None):
        """Get the type even if it is not cached, or default if unknown."""
        try:
            return self[key]
        except KeyError:
            return default

    def get_fields(self, typ):
        """Get the names and types of the fields of a composite type.

        Returns a list of FieldInfo tuples, or None if the type is unknown
        or is not a composite type (has no relid).
        """
        if isinstance(typ, TypeCode):
            type_code = typ
        else:
            type_code = self.get(typ)
            if not type_code:
                return None
        if not type_code.relid:
            return None  # not a composite type
        self._src.execute("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % type_code.relid)
        lookup = self.get
        return [FieldInfo(name, lookup(int(oid)))
                for name, oid in self._src.fetch(-1)]

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        casts = self._typecasts
        return casts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        casts = self._typecasts
        casts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        casts = self._typecasts
        casts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        NULL values (None) and types without a special cast function (or
        with str as cast) are returned unchanged.
        """
        if value is None:
            return None
        cast = self.get_typecast(typ)
        if cast and cast is not str:
            return cast(value)
        return value
669
670
671class _quotedict(dict):
672    """Dictionary with auto quoting of its items.
673
674    The quote attribute must be set to the desired quote function.
675    """
676
677    def __getitem__(self, key):
678        return self.quote(super(_quotedict, self).__getitem__(key))
679
680
681### Error messages
682
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an exception of class cls with msg.

    The sqlstate attribute of the exception is set to None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
688
689
def _op_error(msg):
    """Create (but do not raise) an OperationalError with msg."""
    return _db_error(msg, cls=OperationalError)
693
694
695### Cursor Object
696
697class Cursor(object):
698    """Cursor object."""
699
700    def __init__(self, dbcnx):
701        """Create a cursor object for the database connection."""
702        self.connection = self._dbcnx = dbcnx
703        self._cnx = dbcnx._cnx
704        self.type_cache = dbcnx.type_cache
705        self._src = self._cnx.source()
706        # the official attribute for describing the result columns
707        self._description = None
708        if self.row_factory is Cursor.row_factory:
709            # the row factory needs to be determined dynamically
710            self.row_factory = None
711        else:
712            self.build_row_factory = None
713        self.rowcount = -1
714        self.arraysize = 1
715        self.lastrowid = None
716
717    def __iter__(self):
718        """Make cursor compatible to the iteration protocol."""
719        return self
720
721    def __enter__(self):
722        """Enter the runtime context for the cursor object."""
723        return self
724
725    def __exit__(self, et, ev, tb):
726        """Exit the runtime context for the cursor object."""
727        self.close()
728
729    def _quote(self, value):
730        """Quote value depending on its type."""
731        if value is None:
732            return 'NULL'
733        if isinstance(value,
734                (datetime, date, time, timedelta, Hstore, Json, UUID)):
735            value = str(value)
736        if isinstance(value, basestring):
737            if isinstance(value, Binary):
738                value = self._cnx.escape_bytea(value)
739                if bytes is not str:  # Python >= 3.0
740                    value = value.decode('ascii')
741            else:
742                value = self._cnx.escape_string(value)
743            return "'%s'" % value
744        if isinstance(value, float):
745            if isinf(value):
746                return "'-Infinity'" if value < 0 else "'Infinity'"
747            if isnan(value):
748                return "'NaN'"
749            return value
750        if isinstance(value, (int, long, Decimal, Literal)):
751            return value
752        if isinstance(value, list):
753            # Quote value as an ARRAY constructor. This is better than using
754            # an array literal because it carries the information that this is
755            # an array and not a string.  One issue with this syntax is that
756            # you need to add an explicit typecast when passing empty arrays.
757            # The ARRAY keyword is actually only necessary at the top level.
758            q = self._quote
759            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
760        if isinstance(value, tuple):
761            # Quote as a ROW constructor.  This is better than using a record
762            # literal because it carries the information that this is a record
763            # and not a string.  We don't use the keyword ROW in order to make
764            # this usable with the IN syntax as well.  It is only necessary
765            # when the records has a single column which is not really useful.
766            q = self._quote
767            return '(%s)' % ','.join(str(q(v)) for v in value)
768        try:
769            value = value.__pg_repr__()
770        except AttributeError:
771            raise InterfaceError(
772                'Do not know how to adapt type %s' % type(value))
773        if isinstance(value, (tuple, list)):
774            value = self._quote(value)
775        return value
776
777    def _quoteparams(self, string, parameters):
778        """Quote parameters.
779
780        This function works for both mappings and sequences.
781        """
782        if isinstance(parameters, dict):
783            parameters = _quotedict(parameters)
784            parameters.quote = self._quote
785        else:
786            parameters = tuple(map(self._quote, parameters))
787        return string % parameters
788
789    def _make_description(self, info):
790        """Make the description tuple for the given field info."""
791        name, typ, size, mod = info[1:]
792        type_code = self.type_cache[typ]
793        if mod > 0:
794            mod -= 4
795        if type_code == 'numeric':
796            precision, scale = mod >> 16, mod & 0xffff
797            size = precision
798        else:
799            if not size:
800                size = type_code.size
801            if size == -1:
802                size = mod
803            precision = scale = None
804        return CursorDescription(name, type_code,
805            None, size, precision, scale, None)
806
807    @property
808    def description(self):
809        """Read-only attribute describing the result columns."""
810        descr = self._description
811        if self._description is True:
812            make = self._make_description
813            descr = [make(info) for info in self._src.listinfo()]
814            self._description = descr
815        return descr
816
817    @property
818    def colnames(self):
819        """Unofficial convenience method for getting the column names."""
820        return [d[0] for d in self.description]
821
822    @property
823    def coltypes(self):
824        """Unofficial convenience method for getting the column types."""
825        return [d[1] for d in self.description]
826
827    def close(self):
828        """Close the cursor object."""
829        self._src.close()
830        self._description = None
831        self.rowcount = -1
832        self.lastrowid = None
833
834    def execute(self, operation, parameters=None):
835        """Prepare and execute a database operation (query or command)."""
836        # The parameters may also be specified as list of tuples to e.g.
837        # insert multiple rows in a single operation, but this kind of
838        # usage is deprecated.  We make several plausibility checks because
839        # tuples can also be passed with the meaning of ROW constructors.
840        if (parameters and isinstance(parameters, list)
841                and len(parameters) > 1
842                and all(isinstance(p, tuple) for p in parameters)
843                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
844            return self.executemany(operation, parameters)
845        else:
846            # not a list of tuples
847            return self.executemany(operation, [parameters])
848
849    def executemany(self, operation, seq_of_parameters):
850        """Prepare operation and execute it against a parameter sequence."""
851        if not seq_of_parameters:
852            # don't do anything without parameters
853            return
854        self._description = None
855        self.rowcount = -1
856        # first try to execute all queries
857        rowcount = 0
858        sql = "BEGIN"
859        try:
860            if not self._dbcnx._tnx:
861                try:
862                    self._cnx.source().execute(sql)
863                except DatabaseError:
864                    raise  # database provides error message
865                except Exception as err:
866                    raise _op_error("Can't start transaction")
867                self._dbcnx._tnx = True
868            for parameters in seq_of_parameters:
869                sql = operation
870                if parameters:
871                    sql = self._quoteparams(sql, parameters)
872                rows = self._src.execute(sql)
873                if rows:  # true if not DML
874                    rowcount += rows
875                else:
876                    self.rowcount = -1
877        except DatabaseError:
878            raise  # database provides error message
879        except Error as err:
880            raise _db_error(
881                "Error in '%s': '%s' " % (sql, err), InterfaceError)
882        except Exception as err:
883            raise _op_error("Internal error in '%s': %s" % (sql, err))
884        # then initialize result raw count and description
885        if self._src.resulttype == RESULT_DQL:
886            self._description = True  # fetch on demand
887            self.rowcount = self._src.ntuples
888            self.lastrowid = None
889            if self.build_row_factory:
890                self.row_factory = self.build_row_factory()
891        else:
892            self.rowcount = rowcount
893            self.lastrowid = self._src.oidstatus()
894        # return the cursor object, so you can write statements such as
895        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
896        return self
897
898    def fetchone(self):
899        """Fetch the next row of a query result set."""
900        res = self.fetchmany(1, False)
901        try:
902            return res[0]
903        except IndexError:
904            return None
905
906    def fetchall(self):
907        """Fetch all (remaining) rows of a query result."""
908        return self.fetchmany(-1, False)
909
910    def fetchmany(self, size=None, keep=False):
911        """Fetch the next set of rows of a query result.
912
913        The number of rows to fetch per call is specified by the
914        size parameter. If it is not given, the cursor's arraysize
915        determines the number of rows to be fetched. If you set
916        the keep parameter to true, this is kept as new arraysize.
917        """
918        if size is None:
919            size = self.arraysize
920        if keep:
921            self.arraysize = size
922        try:
923            result = self._src.fetch(size)
924        except DatabaseError:
925            raise
926        except Error as err:
927            raise _db_error(str(err))
928        typecast = self.type_cache.typecast
929        return [self.row_factory([typecast(value, typ)
930            for typ, value in zip(self.coltypes, row)]) for row in result]
931
932    def callproc(self, procname, parameters=None):
933        """Call a stored database procedure with the given name.
934
935        The sequence of parameters must contain one entry for each input
936        argument that the procedure expects. The result of the call is the
937        same as this input sequence; replacement of output and input/output
938        parameters in the return value is currently not supported.
939
940        The procedure may also provide a result set as output. These can be
941        requested through the standard fetch methods of the cursor.
942        """
943        n = parameters and len(parameters) or 0
944        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
945        self.execute(query, parameters)
946        return parameters
947
948    def copy_from(self, stream, table,
949            format=None, sep=None, null=None, size=None, columns=None):
950        """Copy data from an input stream to the specified table.
951
952        The input stream can be a file-like object with a read() method or
953        it can also be an iterable returning a row or multiple rows of input
954        on each iteration.
955
956        The format must be text, csv or binary. The sep option sets the
957        column separator (delimiter) used in the non binary formats.
958        The null option sets the textual representation of NULL in the input.
959
960        The size option sets the size of the buffer used when reading data
961        from file-like objects.
962
963        The copy operation can be restricted to a subset of columns. If no
964        columns are specified, all of them will be copied.
965        """
966        binary_format = format == 'binary'
967        try:
968            read = stream.read
969        except AttributeError:
970            if size:
971                raise ValueError("Size must only be set for file-like objects")
972            if binary_format:
973                input_type = bytes
974                type_name = 'byte strings'
975            else:
976                input_type = basestring
977                type_name = 'strings'
978
979            if isinstance(stream, basestring):
980                if not isinstance(stream, input_type):
981                    raise ValueError("The input must be %s" % type_name)
982                if not binary_format:
983                    if isinstance(stream, str):
984                        if not stream.endswith('\n'):
985                            stream += '\n'
986                    else:
987                        if not stream.endswith(b'\n'):
988                            stream += b'\n'
989
990                def chunks():
991                    yield stream
992
993            elif isinstance(stream, Iterable):
994
995                def chunks():
996                    for chunk in stream:
997                        if not isinstance(chunk, input_type):
998                            raise ValueError(
999                                "Input stream must consist of %s" % type_name)
1000                        if isinstance(chunk, str):
1001                            if not chunk.endswith('\n'):
1002                                chunk += '\n'
1003                        else:
1004                            if not chunk.endswith(b'\n'):
1005                                chunk += b'\n'
1006                        yield chunk
1007
1008            else:
1009                raise TypeError("Need an input stream to copy from")
1010        else:
1011            if size is None:
1012                size = 8192
1013            elif not isinstance(size, int):
1014                raise TypeError("The size option must be an integer")
1015            if size > 0:
1016
1017                def chunks():
1018                    while True:
1019                        buffer = read(size)
1020                        yield buffer
1021                        if not buffer or len(buffer) < size:
1022                            break
1023
1024            else:
1025
1026                def chunks():
1027                    yield read()
1028
1029        if not table or not isinstance(table, basestring):
1030            raise TypeError("Need a table to copy to")
1031        if table.lower().startswith('select'):
1032                raise ValueError("Must specify a table, not a query")
1033        else:
1034            table = '"%s"' % (table,)
1035        operation = ['copy %s' % (table,)]
1036        options = []
1037        params = []
1038        if format is not None:
1039            if not isinstance(format, basestring):
1040                raise TypeError("The frmat option must be be a string")
1041            if format not in ('text', 'csv', 'binary'):
1042                raise ValueError("Invalid format")
1043            options.append('format %s' % (format,))
1044        if sep is not None:
1045            if not isinstance(sep, basestring):
1046                raise TypeError("The sep option must be a string")
1047            if format == 'binary':
1048                raise ValueError(
1049                    "The sep option is not allowed with binary format")
1050            if len(sep) != 1:
1051                raise ValueError(
1052                    "The sep option must be a single one-byte character")
1053            options.append('delimiter %s')
1054            params.append(sep)
1055        if null is not None:
1056            if not isinstance(null, basestring):
1057                raise TypeError("The null option must be a string")
1058            options.append('null %s')
1059            params.append(null)
1060        if columns:
1061            if not isinstance(columns, basestring):
1062                columns = ','.join('"%s"' % (col,) for col in columns)
1063            operation.append('(%s)' % (columns,))
1064        operation.append("from stdin")
1065        if options:
1066            operation.append('(%s)' % ','.join(options))
1067        operation = ' '.join(operation)
1068
1069        putdata = self._src.putdata
1070        self.execute(operation, params)
1071
1072        try:
1073            for chunk in chunks():
1074                putdata(chunk)
1075        except BaseException as error:
1076            self.rowcount = -1
1077            # the following call will re-raise the error
1078            putdata(error)
1079        else:
1080            self.rowcount = putdata(None)
1081
1082        # return the cursor object, so you can chain operations
1083        return self
1084
1085    def copy_to(self, stream, table,
1086            format=None, sep=None, null=None, decode=None, columns=None):
1087        """Copy data from the specified table to an output stream.
1088
1089        The output stream can be a file-like object with a write() method or
1090        it can also be None, in which case the method will return a generator
1091        yielding a row on each iteration.
1092
1093        Output will be returned as byte strings unless you set decode to true.
1094
1095        Note that you can also use a select query instead of the table name.
1096
1097        The format must be text, csv or binary. The sep option sets the
1098        column separator (delimiter) used in the non binary formats.
1099        The null option sets the textual representation of NULL in the output.
1100
1101        The copy operation can be restricted to a subset of columns. If no
1102        columns are specified, all of them will be copied.
1103        """
1104        binary_format = format == 'binary'
1105        if stream is not None:
1106            try:
1107                write = stream.write
1108            except AttributeError:
1109                raise TypeError("Need an output stream to copy to")
1110        if not table or not isinstance(table, basestring):
1111            raise TypeError("Need a table to copy to")
1112        if table.lower().startswith('select'):
1113            if columns:
1114                raise ValueError("Columns must be specified in the query")
1115            table = '(%s)' % (table,)
1116        else:
1117            table = '"%s"' % (table,)
1118        operation = ['copy %s' % (table,)]
1119        options = []
1120        params = []
1121        if format is not None:
1122            if not isinstance(format, basestring):
1123                raise TypeError("The format option must be a string")
1124            if format not in ('text', 'csv', 'binary'):
1125                raise ValueError("Invalid format")
1126            options.append('format %s' % (format,))
1127        if sep is not None:
1128            if not isinstance(sep, basestring):
1129                raise TypeError("The sep option must be a string")
1130            if binary_format:
1131                raise ValueError(
1132                    "The sep option is not allowed with binary format")
1133            if len(sep) != 1:
1134                raise ValueError(
1135                    "The sep option must be a single one-byte character")
1136            options.append('delimiter %s')
1137            params.append(sep)
1138        if null is not None:
1139            if not isinstance(null, basestring):
1140                raise TypeError("The null option must be a string")
1141            options.append('null %s')
1142            params.append(null)
1143        if decode is None:
1144            if format == 'binary':
1145                decode = False
1146            else:
1147                decode = str is unicode
1148        else:
1149            if not isinstance(decode, (int, bool)):
1150                raise TypeError("The decode option must be a boolean")
1151            if decode and binary_format:
1152                raise ValueError(
1153                    "The decode option is not allowed with binary format")
1154        if columns:
1155            if not isinstance(columns, basestring):
1156                columns = ','.join('"%s"' % (col,) for col in columns)
1157            operation.append('(%s)' % (columns,))
1158
1159        operation.append("to stdout")
1160        if options:
1161            operation.append('(%s)' % ','.join(options))
1162        operation = ' '.join(operation)
1163
1164        getdata = self._src.getdata
1165        self.execute(operation, params)
1166
1167        def copy():
1168            self.rowcount = 0
1169            while True:
1170                row = getdata(decode)
1171                if isinstance(row, int):
1172                    if self.rowcount != row:
1173                        self.rowcount = row
1174                    break
1175                self.rowcount += 1
1176                yield row
1177
1178        if stream is None:
1179            # no input stream, return the generator
1180            return copy()
1181
1182        # write the rows to the file-like input stream
1183        for row in copy():
1184            write(row)
1185
1186        # return the cursor object, so you can chain operations
1187        return self
1188
1189    def __next__(self):
1190        """Return the next row (support for the iteration protocol)."""
1191        res = self.fetchone()
1192        if res is None:
1193            raise StopIteration
1194        return res
1195
1196    # Note that since Python 2.6 the iterator protocol uses __next()__
1197    # instead of next(), we keep it only for backward compatibility of pgdb.
1198    next = __next__
1199
1200    @staticmethod
1201    def nextset():
1202        """Not supported."""
1203        raise NotSupportedError("The nextset() method is not supported")
1204
1205    @staticmethod
1206    def setinputsizes(sizes):
1207        """Not supported."""
1208        pass  # unsupported, but silently passed
1209
1210    @staticmethod
1211    def setoutputsize(size, column=0):
1212        """Not supported."""
1213        pass  # unsupported, but silently passed
1214
1215    @staticmethod
1216    def row_factory(row):
1217        """Process rows before they are returned.
1218
1219        You can overwrite this statically with a custom row factory, or
1220        you can build a row factory dynamically with build_row_factory().
1221
1222        For example, you can create a Cursor class that returns rows as
1223        Python dictionaries like this:
1224
1225            class DictCursor(pgdb.Cursor):
1226
1227                def row_factory(self, row):
1228                    return {desc[0]: value
1229                        for desc, value in zip(self.description, row)}
1230
1231            cur = DictCursor(con)  # get one DictCursor instance or
1232            con.cursor_type = DictCursor  # always use DictCursor instances
1233        """
1234        raise NotImplementedError
1235
1236    def build_row_factory(self):
1237        """Build a row factory based on the current description.
1238
1239        This implementation builds a row factory for creating named tuples.
1240        You can overwrite this method if you want to dynamically create
1241        different row factories whenever the column description changes.
1242        """
1243        colnames = self.colnames
1244        if colnames:
1245            try:
1246                try:
1247                    return namedtuple('Row', colnames, rename=True)._make
1248                except TypeError:  # Python 2.6 and 3.0 do not support rename
1249                    colnames = [v if v.isalnum() else 'column_%d' % n
1250                             for n, v in enumerate(colnames)]
1251                    return namedtuple('Row', colnames)._make
1252            except ValueError:  # there is still a problem with the field names
1253                colnames = ['column_%d' % n for n in range(len(colnames))]
1254                return namedtuple('Row', colnames)._make
1255
1256
# Named tuple for the seven items describing one result column, in the
# order required by DB-API 2 for the cursor.description attribute.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
1260
1261
1262### Connection Objects
1263
class Connection(object):
    """Connection object.

    Wraps a low-level connection and keeps track of whether a transaction
    has been started, so that it can be committed or rolled back.
    """

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object wrapping cnx."""
        self._cnx = cnx  # low-level connection, set to None when closed
        self._tnx = False  # true while a transaction has been started
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # class instantiated by cursor()
        # creating a source object doubles as a check that cnx is usable
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("Invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is committed if the context was left normally, and rolled back
        if the context was left because of an exception.
        """
        left_with_exception = not (et is None and ev is None and tb is None)
        if left_with_exception:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back first; database errors
        raised by that rollback are ignored.
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # the connection gets closed anyway
        self._cnx.close()
        self._cnx = None

    def commit(self):
        """Commit any pending transaction to the database."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return  # no transaction has been started
        self._tnx = False
        try:
            self._cnx.source().execute("COMMIT")
        except DatabaseError:
            raise
        except Exception:
            raise _op_error("Can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return  # no transaction has been started
        self._tnx = False
        try:
            self._cnx.source().execute("ROLLBACK")
        except DatabaseError:
            raise
        except Exception:
            raise _op_error("Can't rollback")

    def cursor(self):
        """Return a new cursor object using the connection."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("Invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cur = self.cursor()
            cur.execute(operation, params)
            return cur

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cur = self.cursor()
            cur.executemany(operation, param_seq)
            return cur
1371
1372
1373### Module Interface
1374
# keep the low-level connect function (presumably the one imported from the
# C extension module) before the DB-API connect() below shadows its name
_connect_ = connect

def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database.

    The optional dsn has the form 'host:database:user:password:opt', where
    all parts are optional. The keyword arguments user, password, host and
    database override the corresponding parts of the dsn; the host can be
    given as 'host:port' in order to pass a port as well.

    Returns a Connection object.
    """
    # first get the parameters from the DSN, missing parts stay empty
    try:
        parts = dsn.split(":")
    except (AttributeError, TypeError):
        parts = []
    dbhost, dbbase, dbuser, dbpasswd, dbopt = (list(parts) + [""] * 5)[:5]
    dbport = -1

    # then override them with the keyword arguments
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            hostparts = host.split(":")
        except (AttributeError, TypeError):
            hostparts = None
        if hostparts:
            dbhost = hostparts[0]
            if len(hostparts) > 1:
                try:
                    dbport = int(hostparts[1])
                except (TypeError, ValueError):
                    pass  # keep the default port

    # empty host and user are passed as None (so that the low-level
    # connect can presumably apply its defaults, i.e. a local connection)
    if dbhost == "":
        dbhost = None
    if dbuser == "":
        dbuser = None

    # open the connection
    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
1422
1423
1424### Types Handling
1425
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    An instance is a frozenset of type names that compares equal to any
    type name it contains. A leading underscore (used by PostgreSQL for
    the names of array types) is ignored, so array type names also compare
    equal to the type object of their base type. Comparisons with other
    objects fall back to the frozenset comparison.
    """

    def __new__(cls, values):
        # the type names can also be passed as a space separated string
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Overriding __eq__ implicitly sets __hash__ to None in Python 3, which
    # would make type objects unhashable there; keep the frozenset hash
    # (as inherited in Python 2) so they can be used as dict keys.
    __hash__ = frozenset.__hash__
1453
1454
class ArrayType:
    """Type class for PostgreSQL array types.

    Instances compare equal to other ArrayType instances and to any type
    name starting with an underscore, which is how PostgreSQL names its
    array types (e.g. '_int4' for arrays of int4).
    """

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return isinstance(other, ArrayType)
        return other.startswith('_')

    def __ne__(self, other):
        if not isinstance(other, basestring):
            return not isinstance(other, ArrayType)
        return not other.startswith('_')
1469
1470
class RecordType:
    """Type class for PostgreSQL record types.

    Instances compare equal to other RecordType instances, to the type name
    'record', and to TypeCode objects whose type attribute is 'c'. (In
    pg_type.typtype, 'c' marks composite types; presumably TypeCode.type
    holds that value -- confirm against the TypeCode definition.)
    """

    def __eq__(self, other):
        # TypeCode must be checked first, before the generic string check
        if isinstance(other, TypeCode):
            return other.type == 'c'
        if isinstance(other, basestring):
            return other == 'record'
        return isinstance(other, RecordType)

    def __ne__(self, other):
        if isinstance(other, TypeCode):
            return other.type != 'c'
        if isinstance(other, basestring):
            return other != 'record'
        return not isinstance(other, RecordType)
1489
1490
1491# Mandatory type objects defined by DB-API 2 specs:
1492
STRING = Type(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = Type(['bytea'])
# note that serial is not a true PostgreSQL type, only a notation for int4
NUMBER = Type(['int2', 'int4', 'serial', 'int8',
    'float4', 'float8', 'numeric', 'money'])
DATETIME = Type(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
    'interval', 'abstime', 'reltime'])  # abstime and reltime are very old
ROWID = Type(['oid'])
1499
1500
1501# Additional type objects (more specific):
1502
BOOL = Type(['bool'])
SMALLINT = Type(['int2'])
INTEGER = Type(['int2', 'int4', 'int8', 'serial'])
LONG = Type(['int8'])
FLOAT = Type(['float4', 'float8'])
NUMERIC = Type(['numeric'])
MONEY = Type(['money'])
DATE = Type(['date'])
TIME = Type(['time', 'timetz'])
TIMESTAMP = Type(['timestamp', 'timestamptz'])
INTERVAL = Type(['interval'])
HSTORE = Type(['hstore'])
JSON = Type(['json', 'jsonb'])
1516
# Type object for arrays (note that array type names such as '_int4' compare
# equal to ARRAY as well as to the type object of their base type):

ARRAY = ArrayType()

# Type object for records (encompassing all composite types):

RECORD = RecordType()
1524
1525
1526# Mandatory type helpers defined by DB-API 2 specs:
1527
def Date(year, month, day):
    """Construct an object holding a date value (a datetime.date)."""
    value = date(year, month, day)
    return value
1531
1532
def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Construct an object holding a time value.

    An optional tzinfo object can be passed in order to create a time
    value with a time zone (as stored in the PostgreSQL timetz type).
    """
    return time(hour, minute, second, microsecond, tzinfo)
1536
1537
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Construct an object holding a time stamp value.

    An optional tzinfo object can be passed in order to create a time
    stamp with a time zone (as used with the PostgreSQL timestamptz type).
    """
    return datetime(year, month, day, hour, minute, second, microsecond,
        tzinfo)
1541
1542
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks (seconds since the epoch) are interpreted in local time.
    """
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
1546
1547
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    The ticks (seconds since the epoch) are interpreted in local time;
    fractions of a second are dropped.
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
1551
1552
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    The ticks (seconds since the epoch) are interpreted in local time;
    fractions of a second are dropped.
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
1556
1557
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is a plain subclass of bytes, so instances behave like byte strings.
    """

    pass
1560
1561
1562# Additional type helpers for PyGreSQL:
1563
def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
    """Construct an object holding a time interval value.

    The value is returned as a datetime.timedelta.
    """
    return timedelta(days, hours=hours, minutes=minutes, seconds=seconds,
        microseconds=microseconds)
1568
1569
class Hstore(dict):
    """Wrapper class for marking hstore values.

    The string representation uses the hstore input syntax 'key=>value,...'.
    Keys and values are double-quoted where necessary, double quotes and
    backslashes in them are escaped with a backslash, and None is written
    as NULL.
    """

    # keys or values matching this must be double-quoted: the word NULL
    # (in any letter case) and anything containing whitespace, commas,
    # equal signs or greater-than signs
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>]')
    # double quotes and backslashes must be escaped with a backslash
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s):
        """Quote a single key or value for the hstore input syntax."""
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        quote = cls._re_quote.search(s)
        s = cls._re_escape.sub(r'\\\1', s)
        if quote:
            s = '"%s"' % (s,)
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
1589
1590
class Json:
    """Construct a wrapper for holding an object serializable to JSON.

    The wrapped object is kept in the obj attribute. Converting the wrapper
    to a string returns obj itself if it already is a string, and otherwise
    the result of the encode function applied to obj. The encode function
    defaults to jsonencode if none (or a false value) has been passed.
    """

    def __init__(self, obj, encode=None):
        self.obj = obj
        if encode:
            self.encode = encode
        else:
            self.encode = jsonencode

    def __str__(self):
        value = self.obj
        if not isinstance(value, basestring):
            value = self.encode(value)
        return value
1603
1604
class Literal:
    """Construct a wrapper for holding a literal SQL string.

    The SQL text is returned unchanged both by str() and by __pg_repr__().
    """

    def __init__(self, sql):
        self.sql = sql  # the literal SQL text

    def __str__(self):
        sql = self.sql
        return sql

    __pg_repr__ = __str__
1615
1616# If run as script, print some information:
1617
if __name__ == '__main__':
    # print a single formatted string, so that the output is the same with
    # the Python 2 print statement and with the Python 3 print() function
    print('PyGreSQL version %s' % (version,))
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.