source: trunk/pgdb.py @ 815

Last change on this file since 815 was 815, checked in by cito, 4 years ago

PEP8 recommends not assigning lambda expressions

1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 815 2016-02-03 21:11:13Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary) if they are
32    # passed. The binding syntax is the same as the % operator
33    # for dictionaries, and no quoting is done.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns at most size (or cursor.arraysize by default) rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from functools import partial
76from re import compile as regex
77from json import loads as jsondecode, dumps as jsonencode
78
79try:
80    long
81except NameError:  # Python >= 3.0
82    long = int
83
84try:
85    unicode
86except NameError:  # Python >= 3.0
87    unicode = str
88
89try:
90    basestring
91except NameError:  # Python >= 3.0
92    basestring = (str, bytes)
93
94from collections import Iterable
95
96
97### Module Constants
98
99# compliant with DB API 2.0
100apilevel = '2.0'
101
102# module may be shared, but not connections
103threadsafety = 1
104
105# this module uses extended Python format codes
106paramstyle = 'pyformat'
107
108# shortcut methods have been excluded from DB API 2 and
109# are not recommended by the DB SIG, but they can be handy
110shortcutmethods = 1
111
112
113### Internal Type Handling
114
115try:
116    from inspect import signature
117except ImportError:  # Python < 3.3
118    from inspect import getargspec
119
120    def get_args(func):
121        return getargspec(func).args
122else:
123
124    def get_args(func):
125        return list(signature(func).parameters)
126
127try:
128    if datetime.strptime('+0100', '%z') is None:
129        raise ValueError
130except ValueError:  # Python < 3.2
131    timezones = None
132else:
133    # time zones used in Postgres timestamptz output
134    timezones = dict(CET='+0100', EET='+0200', EST='-0500',
135        GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
136        UCT='+0000', UTC='+0000', WET='+0000')
137
138
139def decimal_type(decimal_type=None):
140    """Get or set global type to be used for decimal values.
141
142    Note that connections cache cast functions. To be sure a global change
143    is picked up by a running connection, call con.type_cache.reset_typecast().
144    """
145    global Decimal
146    if decimal_type is not None:
147        Decimal = decimal_type
148        set_typecast('numeric', decimal_type)
149    return Decimal
150
151
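# Illustrative sketch (not part of the original module): switching the global
# numeric cast to float.  Running connections cache their cast functions, so
# their type cache must be reset for the change to take effect there.
def _example_use_float_for_numeric(con):  # `con` is an assumed open connection
    """Make numeric columns come back as float instead of Decimal."""
    decimal_type(float)               # replaces the global Decimal cast
    con.type_cache.reset_typecast()   # refresh the casts cached by `con`

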
152def cast_bool(value):
153    """Cast boolean value in database format to bool."""
154    if value:
155        return value[0] in ('t', 'T')
156
157
158def cast_money(value):
159    """Cast money value in database format to Decimal."""
160    if value:
161        value = value.replace('(', '-')
162        return Decimal(''.join(c for c in value if c.isdigit() or c in '.-'))
163
164
165def cast_int2vector(value):
166    """Cast an int2vector value."""
167    return [int(v) for v in value.split()]
168
169
170def cast_date(value, connection):
171    """Cast a date value."""
172    # The output format depends on the server setting DateStyle.  The default
173    # setting ISO and the setting for German are actually unambiguous.  The
174    # order of days and months in the other two settings is however ambiguous,
175    # so at least here we need to consult the setting to properly parse values.
176    if value == '-infinity':
177        return date.min
178    if value == 'infinity':
179        return date.max
180    value = value.split()
181    if value[-1] == 'BC':
182        return date.min
183    value = value[0]
184    if len(value) > 10:
185        return date.max
186    fmt = connection.date_format()
187    return datetime.strptime(value, fmt).date()
188
189
190def cast_time(value):
191    """Cast a time value."""
192    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
193    return datetime.strptime(value, fmt).time()
194
195
196_re_timezone = regex('(.*)([+-].*)')
197
198
199def cast_timetz(value):
200    """Cast a timetz value."""
201    tz = _re_timezone.match(value)
202    if tz:
203        value, tz = tz.groups()
204    else:
205        tz = '+0000'
206    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
207    if timezones:
208        if tz.startswith(('+', '-')):
209            if len(tz) < 5:
210                tz += '00'
211            else:
212                tz = tz.replace(':', '')
213        elif tz in timezones:
214            tz = timezones[tz]
215        else:
216            tz = '+0000'
217        value += tz
218        fmt += '%z'
219    return datetime.strptime(value, fmt).timetz()
220
221
222def cast_timestamp(value, connection):
223    """Cast a timestamp value."""
224    if value == '-infinity':
225        return datetime.min
226    if value == 'infinity':
227        return datetime.max
228    value = value.split()
229    if value[-1] == 'BC':
230        return datetime.min
231    fmt = connection.date_format()
232    if fmt.endswith('-%Y') and len(value) > 2:
233        value = value[1:5]
234        if len(value[3]) > 4:
235            return datetime.max
236        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
237            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
238    else:
239        if len(value[0]) > 10:
240            return datetime.max
241        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
242    return datetime.strptime(' '.join(value), ' '.join(fmt))
243
244
245def cast_timestamptz(value, connection):
246    """Cast a timestamptz value."""
247    if value == '-infinity':
248        return datetime.min
249    if value == 'infinity':
250        return datetime.max
251    value = value.split()
252    if value[-1] == 'BC':
253        return datetime.min
254    fmt = connection.date_format()
255    if fmt.endswith('-%Y') and len(value) > 2:
256        value = value[1:]
257        if len(value[3]) > 4:
258            return datetime.max
259        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
260            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
261        value, tz = value[:-1], value[-1]
262    else:
263        if fmt.startswith('%Y-'):
264            tz = _re_timezone.match(value[1])
265            if tz:
266                value[1], tz = tz.groups()
267            else:
268                tz = '+0000'
269        else:
270            value, tz = value[:-1], value[-1]
271        if len(value[0]) > 10:
272            return datetime.max
273        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
274    if timezones:
275        if tz.startswith(('+', '-')):
276            if len(tz) < 5:
277                tz += '00'
278            else:
279                tz = tz.replace(':', '')
280        elif tz in timezones:
281            tz = timezones[tz]
282        else:
283            tz = '+0000'
284        value.append(tz)
285        fmt.append('%z')
286    return datetime.strptime(' '.join(value), ' '.join(fmt))
287
288_re_interval_sql_standard = regex(
289    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
290    '(?:([+-]?[0-9]+)(?!:) ?)?'
291    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
292
293_re_interval_postgres = regex(
294    '(?:([+-]?[0-9]+) ?years? ?)?'
295    '(?:([+-]?[0-9]+) ?mons? ?)?'
296    '(?:([+-]?[0-9]+) ?days? ?)?'
297    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
298
299_re_interval_postgres_verbose = regex(
300    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
301    '(?:([+-]?[0-9]+) ?mons? ?)?'
302    '(?:([+-]?[0-9]+) ?days? ?)?'
303    '(?:([+-]?[0-9]+) ?hours? ?)?'
304    '(?:([+-]?[0-9]+) ?mins? ?)?'
305    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')
306
307_re_interval_iso_8601 = regex(
308    'P(?:([+-]?[0-9]+)Y)?'
309    '(?:([+-]?[0-9]+)M)?'
310    '(?:([+-]?[0-9]+)D)?'
311    '(?:T(?:([+-]?[0-9]+)H)?'
312    '(?:([+-]?[0-9]+)M)?'
313    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')
314
315
316def cast_interval(value):
317    """Cast an interval value."""
318    # The output format depends on the server setting IntervalStyle, but it's
319    # not necessary to consult this setting to parse it.  It's faster to just
320    # check all possible formats, and there is no ambiguity here.
321    m = _re_interval_iso_8601.match(value)
322    if m:
323        m = [d or '0' for d in m.groups()]
324        secs_ago = m.pop(5) == '-'
325        m = [int(d) for d in m]
326        years, mons, days, hours, mins, secs, usecs = m
327        if secs_ago:
328            secs = -secs
329            usecs = -usecs
330    else:
331        m = _re_interval_postgres_verbose.match(value)
332        if m:
333            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
334            secs_ago = m.pop(5) == '-'
335            m = [-int(d) for d in m] if ago else [int(d) for d in m]
336            years, mons, days, hours, mins, secs, usecs = m
337            if secs_ago:
338                secs = -secs
339                usecs = -usecs
340        else:
341            m = _re_interval_postgres.match(value)
342            if m and any(m.groups()):
343                m = [d or '0' for d in m.groups()]
344                hours_ago = m.pop(3) == '-'
345                m = [int(d) for d in m]
346                years, mons, days, hours, mins, secs, usecs = m
347                if hours_ago:
348                    hours = -hours
349                    mins = -mins
350                    secs = -secs
351                    usecs = -usecs
352            else:
353                m = _re_interval_sql_standard.match(value)
354                if m and any(m.groups()):
355                    m = [d or '0' for d in m.groups()]
356                    years_ago = m.pop(0) == '-'
357                    hours_ago = m.pop(3) == '-'
358                    m = [int(d) for d in m]
359                    years, mons, days, hours, mins, secs, usecs = m
360                    if years_ago:
361                        years = -years
362                        mons = -mons
363                    if hours_ago:
364                        hours = -hours
365                        mins = -mins
366                        secs = -secs
367                        usecs = -usecs
368                else:
369                    raise ValueError('Cannot parse interval: %s' % value)
370    days += 365 * years + 30 * mons
371    return timedelta(days=days, hours=hours, minutes=mins,
372        seconds=secs, microseconds=usecs)
373
374
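# Illustrative worked example (not part of the original module) of the
# interval parsing above; note that years and months are approximated as
# 365 and 30 days respectively when the timedelta is built.
def _example_cast_interval():
    """Show what cast_interval() returns for two equivalent inputs."""
    assert cast_interval('1 day 02:00:00') == timedelta(days=1, hours=2)
    assert cast_interval('@ 1 day 2 hours') == timedelta(days=1, hours=2)

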
375class Typecasts(dict):
376    """Dictionary mapping database types to typecast functions.
377
378    The cast functions get passed the string representation of a value in
379    the database which they need to convert to a Python object.  The
380    passed string will never be None since NULL values are already
381    handled before the cast function is called.
382    """
383
384    # the default cast functions
385    # (str functions are ignored but have been added for faster access)
386    defaults = {'char': str, 'bpchar': str, 'name': str,
387        'text': str, 'varchar': str,
388        'bool': cast_bool, 'bytea': unescape_bytea,
389        'int2': int, 'int4': int, 'serial': int,
390        'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
391        'oid': long, 'oid8': long,
392        'float4': float, 'float8': float,
393        'numeric': Decimal, 'money': cast_money,
394        'date': cast_date, 'interval': cast_interval,
395        'time': cast_time, 'timetz': cast_timetz,
396        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
397        'int2vector': cast_int2vector,
398        'anyarray': cast_array, 'record': cast_record}
399
400    connection = None  # will be set in local connection specific instances
401
402    def __missing__(self, typ):
403        """Create a cast function if it is not cached.
404
405        Note that this class never raises a KeyError,
406        but returns None when no special cast function exists.
407        """
408        if not isinstance(typ, str):
409            raise TypeError('Invalid type: %s' % typ)
410        cast = self.defaults.get(typ)
411        if cast:
412            # store default for faster access
413            cast = self._add_connection(cast)
414            self[typ] = cast
415        elif typ.startswith('_'):
416            # create array cast
417            base_cast = self[typ[1:]]
418            cast = self.create_array_cast(base_cast)
419            if base_cast:
420                # store only if base type exists
421                self[typ] = cast
422        return cast
423
424    @staticmethod
425    def _needs_connection(func):
426        """Check if a typecast function needs a connection argument."""
427        try:
428            args = get_args(func)
429        except (TypeError, ValueError):
430            return False
431        else:
432            return 'connection' in args[1:]
433
434    def _add_connection(self, cast):
435        """Add a connection argument to the typecast function if necessary."""
436        if not self.connection or not self._needs_connection(cast):
437            return cast
438        return partial(cast, connection=self.connection)
439
440    def get(self, typ, default=None):
441        """Get the typecast function for the given database type."""
442        return self[typ] or default
443
444    def set(self, typ, cast):
445        """Set a typecast function for the specified database type(s)."""
446        if isinstance(typ, basestring):
447            typ = [typ]
448        if cast is None:
449            for t in typ:
450                self.pop(t, None)
451                self.pop('_%s' % t, None)
452        else:
453            if not callable(cast):
454                raise TypeError("Cast parameter must be callable")
455            for t in typ:
456                self[t] = self._add_connection(cast)
457                self.pop('_%s' % t, None)
458
459    def reset(self, typ=None):
460        """Reset the typecasts for the specified type(s) to their defaults.
461
462        When no type is specified, all typecasts will be reset.
463        """
464        defaults = self.defaults
465        if typ is None:
466            self.clear()
467            self.update(defaults)
468        else:
469            if isinstance(typ, basestring):
470                typ = [typ]
471            for t in typ:
472                cast = defaults.get(t)
473                if cast:
474                    self[t] = self._add_connection(cast)
475                    t = '_%s' % t
476                    cast = defaults.get(t)
477                    if cast:
478                        self[t] = self._add_connection(cast)
479                    else:
480                        self.pop(t, None)
481                else:
482                    self.pop(t, None)
483                    self.pop('_%s' % t, None)
484
485    def create_array_cast(self, basecast):
486        """Create an array typecast for the given base cast."""
487        def cast(v):
488            return cast_array(v, basecast)
489        return cast
490
491    def create_record_cast(self, name, fields, casts):
492        """Create a named record typecast for the given fields and casts."""
493        record = namedtuple(name, fields)
494        def cast(v):
495            return record(*cast_record(v, casts))
496        return cast
497
498
499_typecasts = Typecasts()  # this is the global typecast dictionary
500
501
502def get_typecast(typ):
503    """Get the global typecast function for the given database type(s)."""
504    return _typecasts.get(typ)
505
506
507def set_typecast(typ, cast):
508    """Set a global typecast function for the given database type(s).
509
510    Note that connections cache cast functions. To be sure a global change
511    is picked up by a running connection, call con.type_cache.reset_typecast().
512    """
513    _typecasts.set(typ, cast)
514
515
516def reset_typecast(typ=None):
517    """Reset the global typecasts for the given type(s) to their default.
518
519    When no type is specified, all typecasts will be reset.
520
521    Note that connections cache cast functions. To be sure a global change
522    is picked up by a running connection, call con.type_cache.reset_typecast().
523    """
524    _typecasts.reset(typ)
525
526
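# Illustrative sketch (not part of the original module): registering a custom
# global cast for the PostgreSQL uuid type, which has no default cast here,
# and dropping it again.  The standard uuid module is used only as an example.
def _example_cast_uuid_globally():
    """Register and then remove a global typecast for 'uuid' columns."""
    from uuid import UUID
    set_typecast('uuid', UUID)   # uuid values now come back as UUID objects
    reset_typecast('uuid')       # restore the default (plain strings)

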
527class LocalTypecasts(Typecasts):
528    """Map typecasts, including local composite types, to cast functions."""
529
530    defaults = _typecasts
531
532    connection = None  # will be set in a connection specific instance
533
534    def __missing__(self, typ):
535        """Create a cast function if it is not cached."""
536        if typ.startswith('_'):
537            base_cast = self[typ[1:]]
538            cast = self.create_array_cast(base_cast)
539            if base_cast:
540                self[typ] = cast
541        else:
542            cast = self.defaults.get(typ)
543            if cast:
544                cast = self._add_connection(cast)
545                self[typ] = cast
546            else:
547                fields = self.get_fields(typ)
548                if fields:
549                    casts = [self[field.type] for field in fields]
550                    fields = [field.name for field in fields]
551                    cast = self.create_record_cast(typ, fields, casts)
552                    self[typ] = cast
553        return cast
554
555    def get_fields(self, typ):
556        """Return the fields for the given record type.
557
558        This method will be replaced with a method that looks up the fields
559        using the type cache of the connection.
560        """
561        return []
562
563
564class TypeCode(str):
565    """Class representing the type_code used by the DB-API 2.0.
566
567    TypeCode objects are strings equal to the PostgreSQL type name,
568    but carry some additional information.
569    """
570
571    @classmethod
572    def create(cls, oid, name, len, type, category, delim, relid):
573        """Create a type code for a PostgreSQL data type."""
574        self = cls(name)
575        self.oid = oid
576        self.len = len
577        self.type = type
578        self.category = category
579        self.delim = delim
580        self.relid = relid
581        return self
582
583FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
584
585
586class TypeCache(dict):
587    """Cache for database types.
588
589    This cache maps type OIDs and names to TypeCode strings containing
590    important information on the associated database type.
591    """
592
593    def __init__(self, cnx):
594        """Initialize type cache for connection."""
595        super(TypeCache, self).__init__()
596        self._escape_string = cnx.escape_string
597        self._src = cnx.source()
598        self._typecasts = LocalTypecasts()
599        self._typecasts.get_fields = self.get_fields
600        self._typecasts.connection = cnx
601
602    def __missing__(self, key):
603        """Get the type info from the database if it is not cached."""
604        if isinstance(key, int):
605            oid = key
606        else:
607            if '.' not in key and '"' not in key:
608                key = '"%s"' % key
609            oid = "'%s'::regtype" % self._escape_string(key)
610        try:
611            self._src.execute("SELECT oid, typname,"
612                 " typlen, typtype, typcategory, typdelim, typrelid"
613                " FROM pg_type WHERE oid=%s" % oid)
614        except ProgrammingError:
615            res = None
616        else:
617            res = self._src.fetch(1)
618        if not res:
619            raise KeyError('Type %s could not be found' % key)
620        res = res[0]
621        type_code = TypeCode.create(int(res[0]), res[1],
622            int(res[2]), res[3], res[4], res[5], int(res[6]))
623        self[type_code.oid] = self[str(type_code)] = type_code
624        return type_code
625
626    def get(self, key, default=None):
627        """Get the type even if it is not cached."""
628        try:
629            return self[key]
630        except KeyError:
631            return default
632
633    def get_fields(self, typ):
634        """Get the names and types of the fields of composite types."""
635        if not isinstance(typ, TypeCode):
636            typ = self.get(typ)
637            if not typ:
638                return None
639        if not typ.relid:
640            return None  # this type is not composite
641        self._src.execute("SELECT attname, atttypid"
642            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
643            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
644        return [FieldInfo(name, self.get(int(oid)))
645            for name, oid in self._src.fetch(-1)]
646
647    def get_typecast(self, typ):
648        """Get the typecast function for the given database type."""
649        return self._typecasts.get(typ)
650
651    def set_typecast(self, typ, cast):
652        """Set a typecast function for the specified database type(s)."""
653        self._typecasts.set(typ, cast)
654
655    def reset_typecast(self, typ=None):
656        """Reset the typecast function for the specified database type(s)."""
657        self._typecasts.reset(typ)
658
659    def typecast(self, value, typ):
660        """Cast the given value according to the given database type."""
661        if value is None:
662            # for NULL values, no typecast is necessary
663            return None
664        cast = self.get_typecast(typ)
665        if not cast or cast is str:
666            # no typecast is necessary
667            return value
668        return cast(value)
669
670
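# Illustrative sketch (not part of the original module): the type cache of an
# open connection can be queried directly (here `con` is an assumed pgdb
# connection).  TypeCode objects compare equal to the plain type name but
# carry the additional catalog information described above.
def _example_type_cache_lookup(con):
    """Look up catalog information for the int4 type via the type cache."""
    type_code = con.type_cache['int4']        # also cached under its OID
    return type_code.oid, type_code.category  # e.g. (23, 'N') for int4

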
671class _quotedict(dict):
672    """Dictionary with auto quoting of its items.
673
674    The quote attribute must be set to the desired quote function.
675    """
676
677    def __getitem__(self, key):
678        return self.quote(super(_quotedict, self).__getitem__(key))
679
680
681### Error messages
682
683def _db_error(msg, cls=DatabaseError):
684    """Return DatabaseError with empty sqlstate attribute."""
685    error = cls(msg)
686    error.sqlstate = None
687    return error
688
689
690def _op_error(msg):
691    """Return OperationalError."""
692    return _db_error(msg, OperationalError)
693
694
695### Cursor Object
696
697class Cursor(object):
698    """Cursor object."""
699
700    def __init__(self, dbcnx):
701        """Create a cursor object for the database connection."""
702        self.connection = self._dbcnx = dbcnx
703        self._cnx = dbcnx._cnx
704        self.type_cache = dbcnx.type_cache
705        self._src = self._cnx.source()
706        # the official attribute for describing the result columns
707        self._description = None
708        if self.row_factory is Cursor.row_factory:
709            # the row factory needs to be determined dynamically
710            self.row_factory = None
711        else:
712            self.build_row_factory = None
713        self.rowcount = -1
714        self.arraysize = 1
715        self.lastrowid = None
716
717    def __iter__(self):
718        """Make the cursor compatible with the iteration protocol."""
719        return self
720
721    def __enter__(self):
722        """Enter the runtime context for the cursor object."""
723        return self
724
725    def __exit__(self, et, ev, tb):
726        """Exit the runtime context for the cursor object."""
727        self.close()
728
729    def _quote(self, value):
730        """Quote value depending on its type."""
731        if value is None:
732            return 'NULL'
733        if isinstance(value, (datetime, date, time, timedelta, Json)):
734            value = str(value)
735        if isinstance(value, basestring):
736            if isinstance(value, Binary):
737                value = self._cnx.escape_bytea(value)
738                if bytes is not str:  # Python >= 3.0
739                    value = value.decode('ascii')
740            else:
741                value = self._cnx.escape_string(value)
742            return "'%s'" % value
743        if isinstance(value, float):
744            if isinf(value):
745                return "'-Infinity'" if value < 0 else "'Infinity'"
746            if isnan(value):
747                return "'NaN'"
748            return value
749        if isinstance(value, (int, long, Decimal, Literal)):
750            return value
751        if isinstance(value, list):
752            # Quote value as an ARRAY constructor. This is better than using
753            # an array literal because it carries the information that this is
754            # an array and not a string.  One issue with this syntax is that
755            # you need to add an explicit typecast when passing empty arrays.
756            # The ARRAY keyword is actually only necessary at the top level.
757            q = self._quote
758            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
759        if isinstance(value, tuple):
760            # Quote as a ROW constructor.  This is better than using a record
761            # literal because it carries the information that this is a record
762            # and not a string.  We don't use the keyword ROW in order to make
763            # this usable with the IN syntax as well.  It is only necessary
764            # when the record has a single column, which is not really useful.
765            q = self._quote
766            return '(%s)' % ','.join(str(q(v)) for v in value)
767        try:
768            value = value.__pg_repr__()
769        except AttributeError:
770            raise InterfaceError(
771                'Do not know how to adapt type %s' % type(value))
772        if isinstance(value, (tuple, list)):
773            value = self._quote(value)
774        return value
775
776    def _quoteparams(self, string, parameters):
777        """Quote parameters.
778
779        This function works for both mappings and sequences.
780        """
781        if isinstance(parameters, dict):
782            parameters = _quotedict(parameters)
783            parameters.quote = self._quote
784        else:
785            parameters = tuple(map(self._quote, parameters))
786        return string % parameters
787
788    def _make_description(self, info):
789        """Make the description tuple for the given field info."""
790        name, typ, size, mod = info[1:]
791        type_code = self.type_cache[typ]
792        if mod > 0:
793            mod -= 4
794        if type_code == 'numeric':
795            precision, scale = mod >> 16, mod & 0xffff
796            size = precision
797        else:
798            if not size:
799                size = type_code.len  # TypeCode stores typlen as 'len'
800            if size == -1:
801                size = mod
802            precision = scale = None
803        return CursorDescription(name, type_code,
804            None, size, precision, scale, None)
805
806    @property
807    def description(self):
808        """Read-only attribute describing the result columns."""
809        descr = self._description
810        if self._description is True:
811            make = self._make_description
812            descr = [make(info) for info in self._src.listinfo()]
813            self._description = descr
814        return descr
815
816    @property
817    def colnames(self):
818        """Unofficial convenience method for getting the column names."""
819        return [d[0] for d in self.description]
820
821    @property
822    def coltypes(self):
823        """Unofficial convenience method for getting the column types."""
824        return [d[1] for d in self.description]
825
826    def close(self):
827        """Close the cursor object."""
828        self._src.close()
829        self._description = None
830        self.rowcount = -1
831        self.lastrowid = None
832
833    def execute(self, operation, parameters=None):
834        """Prepare and execute a database operation (query or command)."""
835        # The parameters may also be specified as a list of tuples, e.g. to
836        # insert multiple rows in a single operation, but this kind of
837        # usage is deprecated.  We make several plausibility checks because
838        # tuples can also be passed with the meaning of ROW constructors.
839        if (parameters and isinstance(parameters, list)
840                and len(parameters) > 1
841                and all(isinstance(p, tuple) for p in parameters)
842                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
843            return self.executemany(operation, parameters)
844        else:
845            # not a list of tuples
846            return self.executemany(operation, [parameters])
847
848    def executemany(self, operation, seq_of_parameters):
849        """Prepare operation and execute it against a parameter sequence."""
850        if not seq_of_parameters:
851            # don't do anything without parameters
852            return
853        self._description = None
854        self.rowcount = -1
855        # first try to execute all queries
856        rowcount = 0
857        sql = "BEGIN"
858        try:
859            if not self._dbcnx._tnx:
860                try:
861                    self._cnx.source().execute(sql)
862                except DatabaseError:
863                    raise  # database provides error message
864                except Exception as err:
865                    raise _op_error("Can't start transaction")
866                self._dbcnx._tnx = True
867            for parameters in seq_of_parameters:
868                sql = operation
869                if parameters:
870                    sql = self._quoteparams(sql, parameters)
871                rows = self._src.execute(sql)
872                if rows:  # true if not DML
873                    rowcount += rows
874                else:
875                    self.rowcount = -1
876        except DatabaseError:
877            raise  # database provides error message
878        except Error as err:
879            raise _db_error(
880                "Error in '%s': '%s' " % (sql, err), InterfaceError)
881        except Exception as err:
882            raise _op_error("Internal error in '%s': %s" % (sql, err))
883        # then initialize the result row count and description
884        if self._src.resulttype == RESULT_DQL:
885            self._description = True  # fetch on demand
886            self.rowcount = self._src.ntuples
887            self.lastrowid = None
888            if self.build_row_factory:
889                self.row_factory = self.build_row_factory()
890        else:
891            self.rowcount = rowcount
892            self.lastrowid = self._src.oidstatus()
893        # return the cursor object, so you can write statements such as
894        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
895        return self
896
897    def fetchone(self):
898        """Fetch the next row of a query result set."""
899        res = self.fetchmany(1, False)
900        try:
901            return res[0]
902        except IndexError:
903            return None
904
905    def fetchall(self):
906        """Fetch all (remaining) rows of a query result."""
907        return self.fetchmany(-1, False)
908
909    def fetchmany(self, size=None, keep=False):
910        """Fetch the next set of rows of a query result.
911
912        The number of rows to fetch per call is specified by the
913        size parameter. If it is not given, the cursor's arraysize
914        determines the number of rows to be fetched. If you set
915        the keep parameter to true, this is kept as new arraysize.
916        """
917        if size is None:
918            size = self.arraysize
919        if keep:
920            self.arraysize = size
921        try:
922            result = self._src.fetch(size)
923        except DatabaseError:
924            raise
925        except Error as err:
926            raise _db_error(str(err))
927        typecast = self.type_cache.typecast
928        return [self.row_factory([typecast(value, typ)
929            for typ, value in zip(self.coltypes, row)]) for row in result]
930
931    def callproc(self, procname, parameters=None):
932        """Call a stored database procedure with the given name.
933
934        The sequence of parameters must contain one entry for each input
935        argument that the procedure expects. The result of the call is the
936        same as this input sequence; replacement of output and input/output
937        parameters in the return value is currently not supported.
938
939        The procedure may also provide a result set as output. These can be
940        requested through the standard fetch methods of the cursor.
941        """
942        n = parameters and len(parameters) or 0
943        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
944        self.execute(query, parameters)
945        return parameters
946
947    def copy_from(self, stream, table,
948            format=None, sep=None, null=None, size=None, columns=None):
949        """Copy data from an input stream to the specified table.
950
951        The input stream can be a file-like object with a read() method or
952        it can also be an iterable returning a row or multiple rows of input
953        on each iteration.
954
955        The format must be text, csv or binary. The sep option sets the
956        column separator (delimiter) used in the non-binary formats.
957        The null option sets the textual representation of NULL in the input.
958
959        The size option sets the size of the buffer used when reading data
960        from file-like objects.
961
962        The copy operation can be restricted to a subset of columns. If no
963        columns are specified, all of them will be copied.
964        """
965        binary_format = format == 'binary'
966        try:
967            read = stream.read
968        except AttributeError:
969            if size:
970                raise ValueError("Size must only be set for file-like objects")
971            if binary_format:
972                input_type = bytes
973                type_name = 'byte strings'
974            else:
975                input_type = basestring
976                type_name = 'strings'
977
978            if isinstance(stream, basestring):
979                if not isinstance(stream, input_type):
980                    raise ValueError("The input must be %s" % type_name)
981                if not binary_format:
982                    if isinstance(stream, str):
983                        if not stream.endswith('\n'):
984                            stream += '\n'
985                    else:
986                        if not stream.endswith(b'\n'):
987                            stream += b'\n'
988
989                def chunks():
990                    yield stream
991
992            elif isinstance(stream, Iterable):
993
994                def chunks():
995                    for chunk in stream:
996                        if not isinstance(chunk, input_type):
997                            raise ValueError(
998                                "Input stream must consist of %s" % type_name)
999                        if isinstance(chunk, str):
1000                            if not chunk.endswith('\n'):
1001                                chunk += '\n'
1002                        else:
1003                            if not chunk.endswith(b'\n'):
1004                                chunk += b'\n'
1005                        yield chunk
1006
1007            else:
1008                raise TypeError("Need an input stream to copy from")
1009        else:
1010            if size is None:
1011                size = 8192
1012            elif not isinstance(size, int):
1013                raise TypeError("The size option must be an integer")
1014            if size > 0:
1015
1016                def chunks():
1017                    while True:
1018                        buffer = read(size)
1019                        yield buffer
1020                        if not buffer or len(buffer) < size:
1021                            break
1022
1023            else:
1024
1025                def chunks():
1026                    yield read()
1027
1028        if not table or not isinstance(table, basestring):
1029            raise TypeError("Need a table to copy to")
1030        if table.lower().startswith('select'):
1031            raise ValueError("Must specify a table, not a query")
1032        else:
1033            table = '"%s"' % (table,)
1034        operation = ['copy %s' % (table,)]
1035        options = []
1036        params = []
1037        if format is not None:
1038            if not isinstance(format, basestring):
1039                raise TypeError("The format option must be a string")
1040            if format not in ('text', 'csv', 'binary'):
1041                raise ValueError("Invalid format")
1042            options.append('format %s' % (format,))
1043        if sep is not None:
1044            if not isinstance(sep, basestring):
1045                raise TypeError("The sep option must be a string")
1046            if format == 'binary':
1047                raise ValueError(
1048                    "The sep option is not allowed with binary format")
1049            if len(sep) != 1:
1050                raise ValueError(
1051                    "The sep option must be a single one-byte character")
1052            options.append('delimiter %s')
1053            params.append(sep)
1054        if null is not None:
1055            if not isinstance(null, basestring):
1056                raise TypeError("The null option must be a string")
1057            options.append('null %s')
1058            params.append(null)
1059        if columns:
1060            if not isinstance(columns, basestring):
1061                columns = ','.join('"%s"' % (col,) for col in columns)
1062            operation.append('(%s)' % (columns,))
1063        operation.append("from stdin")
1064        if options:
1065            operation.append('(%s)' % ','.join(options))
1066        operation = ' '.join(operation)
1067
1068        putdata = self._src.putdata
1069        self.execute(operation, params)
1070
1071        try:
1072            for chunk in chunks():
1073                putdata(chunk)
1074        except BaseException as error:
1075            self.rowcount = -1
1076            # the following call will re-raise the error
1077            putdata(error)
1078        else:
1079            self.rowcount = putdata(None)
1080
1081        # return the cursor object, so you can chain operations
1082        return self
1083
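    # Illustrative sketch (not part of the original module), assuming an open
    # cursor `cur` and a hypothetical table "fruits" with columns (id, name):
    #
    #     data = ['1\tapple\n', '2\tbanana\n']   # default text format, tab-separated
    #     cur.copy_from(data, 'fruits', columns=('id', 'name'))
    #     cur.connection.commit()
    #
    # A file-like object works as well; it is then read in chunks of `size`
    # bytes (8192 by default).
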
1084    def copy_to(self, stream, table,
1085            format=None, sep=None, null=None, decode=None, columns=None):
1086        """Copy data from the specified table to an output stream.
1087
1088        The output stream can be a file-like object with a write() method or
1089        it can also be None, in which case the method will return a generator
1090        yielding a row on each iteration.
1091
1092        Output will be returned as byte strings unless you set decode to true.
1093
1094        Note that you can also use a select query instead of the table name.
1095
1096        The format must be text, csv or binary. The sep option sets the
1097        column separator (delimiter) used in the non-binary formats.
1098        The null option sets the textual representation of NULL in the output.
1099
1100        The copy operation can be restricted to a subset of columns. If no
1101        columns are specified, all of them will be copied.
1102        """
1103        binary_format = format == 'binary'
1104        if stream is not None:
1105            try:
1106                write = stream.write
1107            except AttributeError:
1108                raise TypeError("Need an output stream to copy to")
1109        if not table or not isinstance(table, basestring):
1110            raise TypeError("Need a table to copy from")
1111        if table.lower().startswith('select'):
1112            if columns:
1113                raise ValueError("Columns must be specified in the query")
1114            table = '(%s)' % (table,)
1115        else:
1116            table = '"%s"' % (table,)
1117        operation = ['copy %s' % (table,)]
1118        options = []
1119        params = []
1120        if format is not None:
1121            if not isinstance(format, basestring):
1122                raise TypeError("The format option must be a string")
1123            if format not in ('text', 'csv', 'binary'):
1124                raise ValueError("Invalid format")
1125            options.append('format %s' % (format,))
1126        if sep is not None:
1127            if not isinstance(sep, basestring):
1128                raise TypeError("The sep option must be a string")
1129            if binary_format:
1130                raise ValueError(
1131                    "The sep option is not allowed with binary format")
1132            if len(sep) != 1:
1133                raise ValueError(
1134                    "The sep option must be a single one-byte character")
1135            options.append('delimiter %s')
1136            params.append(sep)
1137        if null is not None:
1138            if not isinstance(null, basestring):
1139                raise TypeError("The null option must be a string")
1140            options.append('null %s')
1141            params.append(null)
1142        if decode is None:
1143            if format == 'binary':
1144                decode = False
1145            else:
1146                decode = str is unicode
1147        else:
1148            if not isinstance(decode, (int, bool)):
1149                raise TypeError("The decode option must be a boolean")
1150            if decode and binary_format:
1151                raise ValueError(
1152                    "The decode option is not allowed with binary format")
1153        if columns:
1154            if not isinstance(columns, basestring):
1155                columns = ','.join('"%s"' % (col,) for col in columns)
1156            operation.append('(%s)' % (columns,))
1157
1158        operation.append("to stdout")
1159        if options:
1160            operation.append('(%s)' % ','.join(options))
1161        operation = ' '.join(operation)
1162
1163        getdata = self._src.getdata
1164        self.execute(operation, params)
1165
1166        def copy():
1167            self.rowcount = 0
1168            while True:
1169                row = getdata(decode)
1170                if isinstance(row, int):
1171                    if self.rowcount != row:
1172                        self.rowcount = row
1173                    break
1174                self.rowcount += 1
1175                yield row
1176
1177        if stream is None:
1178            # no input stream, return the generator
1179            return copy()
1180
1181        # write the rows to the file-like output stream
1182        for row in copy():
1183            write(row)
1184
1185        # return the cursor object, so you can chain operations
1186        return self
1187
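    # Illustrative sketch (not part of the original module), again assuming an
    # open cursor `cur` and the hypothetical table "fruits" from above:
    #
    #     for line in cur.copy_to(None, 'fruits', format='csv', decode=True):
    #         print(line)   # one decoded csv line per table row
    #
    # Passing a file-like object instead of None writes the lines to it and
    # returns the cursor so that calls can be chained.
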
1188    def __next__(self):
1189        """Return the next row (support for the iteration protocol)."""
1190        res = self.fetchone()
1191        if res is None:
1192            raise StopIteration
1193        return res
1194
1195    # Note that since Python 2.6 the iterator protocol uses __next__()
1196    # instead of next(); we keep next() only for backward compatibility.
1197    next = __next__
1198
1199    @staticmethod
1200    def nextset():
1201        """Not supported."""
1202        raise NotSupportedError("The nextset() method is not supported")
1203
1204    @staticmethod
1205    def setinputsizes(sizes):
1206        """Not supported."""
1207        pass  # unsupported, but silently passed
1208
1209    @staticmethod
1210    def setoutputsize(size, column=0):
1211        """Not supported."""
1212        pass  # unsupported, but silently passed
1213
1214    @staticmethod
1215    def row_factory(row):
1216        """Process rows before they are returned.
1217
1218        You can overwrite this statically with a custom row factory, or
1219        you can build a row factory dynamically with build_row_factory().
1220
1221        For example, you can create a Cursor class that returns rows as
1222        Python dictionaries like this:
1223
1224            class DictCursor(pgdb.Cursor):
1225
1226                def row_factory(self, row):
1227                    return {desc[0]: value
1228                        for desc, value in zip(self.description, row)}
1229
1230            cur = DictCursor(con)  # get one DictCursor instance or
1231            con.cursor_type = DictCursor  # always use DictCursor instances
1232        """
1233        raise NotImplementedError
1234
1235    def build_row_factory(self):
1236        """Build a row factory based on the current description.
1237
1238        This implementation builds a row factory for creating named tuples.
1239        You can overwrite this method if you want to dynamically create
1240        different row factories whenever the column description changes.
1241        """
1242        colnames = self.colnames
1243        if colnames:
1244            try:
1245                try:
1246                    return namedtuple('Row', colnames, rename=True)._make
1247                except TypeError:  # Python 2.6 and 3.0 do not support rename
1248                    colnames = [v if v.isalnum() else 'column_%d' % n
1249                             for n, v in enumerate(colnames)]
1250                    return namedtuple('Row', colnames)._make
1251            except ValueError:  # there is still a problem with the field names
1252                colnames = ['column_%d' % n for n in range(len(colnames))]
1253                return namedtuple('Row', colnames)._make
1254
1255
1256CursorDescription = namedtuple('CursorDescription',
1257    ['name', 'type_code', 'display_size', 'internal_size',
1258     'precision', 'scale', 'null_ok'])
1259
1260
1261### Connection Objects
1262
1263class Connection(object):
1264    """Connection object."""
1265
1266    # expose the exceptions as attributes on the connection object
1267    Error = Error
1268    Warning = Warning
1269    InterfaceError = InterfaceError
1270    DatabaseError = DatabaseError
1271    InternalError = InternalError
1272    OperationalError = OperationalError
1273    ProgrammingError = ProgrammingError
1274    IntegrityError = IntegrityError
1275    DataError = DataError
1276    NotSupportedError = NotSupportedError
1277
1278    def __init__(self, cnx):
1279        """Create a database connection object."""
1280        self._cnx = cnx  # connection
1281        self._tnx = False  # transaction state
1282        self.type_cache = TypeCache(cnx)
1283        self.cursor_type = Cursor
1284        try:
1285            self._cnx.source()
1286        except Exception:
1287            raise _op_error("Invalid connection")
1288
1289    def __enter__(self):
1290        """Enter the runtime context for the connection object.
1291
1292        The runtime context can be used for running transactions.
1293        """
1294        return self
1295
1296    def __exit__(self, et, ev, tb):
1297        """Exit the runtime context for the connection object.
1298
1299        This does not close the connection, but it ends a transaction.
1300        """
1301        if et is None and ev is None and tb is None:
1302            self.commit()
1303        else:
1304            self.rollback()
1305
1306    def close(self):
1307        """Close the connection object."""
1308        if self._cnx:
1309            if self._tnx:
1310                try:
1311                    self.rollback()
1312                except DatabaseError:
1313                    pass
1314            self._cnx.close()
1315            self._cnx = None
1316        else:
1317            raise _op_error("Connection has been closed")
1318
1319    def commit(self):
1320        """Commit any pending transaction to the database."""
1321        if self._cnx:
1322            if self._tnx:
1323                self._tnx = False
1324                try:
1325                    self._cnx.source().execute("COMMIT")
1326                except DatabaseError:
1327                    raise
1328                except Exception:
1329                    raise _op_error("Can't commit")
1330        else:
1331            raise _op_error("Connection has been closed")
1332
1333    def rollback(self):
1334        """Roll back to the start of any pending transaction."""
1335        if self._cnx:
1336            if self._tnx:
1337                self._tnx = False
1338                try:
1339                    self._cnx.source().execute("ROLLBACK")
1340                except DatabaseError:
1341                    raise
1342                except Exception:
1343                    raise _op_error("Can't rollback")
1344        else:
1345            raise _op_error("Connection has been closed")
1346
1347    def cursor(self):
1348        """Return a new cursor object using the connection."""
1349        if self._cnx:
1350            try:
1351                return self.cursor_type(self)
1352            except Exception:
1353                raise _op_error("Invalid connection")
1354        else:
1355            raise _op_error("Connection has been closed")
1356
1357    if shortcutmethods:  # otherwise do not implement or document these methods
1358
1359        def execute(self, operation, params=None):
1360            """Shortcut method to run an operation on an implicit cursor."""
1361            cursor = self.cursor()
1362            cursor.execute(operation, params)
1363            return cursor
1364
1365        def executemany(self, operation, param_seq):
1366            """Shortcut method to run an operation against a sequence."""
1367            cursor = self.cursor()
1368            cursor.executemany(operation, param_seq)
1369            return cursor
1370
1371
1372### Module Interface
1373
1374_connect_ = connect
1375
1376def connect(dsn=None,
1377        user=None, password=None,
1378        host=None, database=None):
1379    """Connect to a database."""
1380    # first get params from DSN
1381    dbport = -1
1382    dbhost = ""
1383    dbbase = ""
1384    dbuser = ""
1385    dbpasswd = ""
1386    dbopt = ""
1387    try:
1388        params = dsn.split(":")
1389        dbhost = params[0]
1390        dbbase = params[1]
1391        dbuser = params[2]
1392        dbpasswd = params[3]
1393        dbopt = params[4]
1394    except (AttributeError, IndexError, TypeError):
1395        pass
1396
1397    # override if necessary
1398    if user is not None:
1399        dbuser = user
1400    if password is not None:
1401        dbpasswd = password
1402    if database is not None:
1403        dbbase = database
1404    if host is not None:
1405        try:
1406            params = host.split(":")
1407            dbhost = params[0]
1408            dbport = int(params[1])
1409        except (AttributeError, IndexError, TypeError, ValueError):
1410            pass
1411
1412    # empty host is localhost
1413    if dbhost == "":
1414        dbhost = None
1415    if dbuser == "":
1416        dbuser = None
1417
1418    # open the connection
1419    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
1420    return Connection(cnx)
1421
1422
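# Illustrative sketch (not part of the original module): opening a connection
# and running a query with the 'pyformat' parameter style.  The database name,
# credentials and table used below are placeholders, not part of PyGreSQL.
def _example_connect_and_query():
    """Connect, run a parameterized query and fetch the rows as named tuples."""
    con = connect(database='testdb', host='localhost:5432', user='scott')
    try:
        with con:  # the with block commits on success and rolls back on error
            cur = con.cursor()
            cur.execute("SELECT * FROM fruits WHERE name = %(name)s",
                {'name': 'apple'})      # parameter values are quoted automatically
            return cur.fetchall()       # rows are named tuples by default
    finally:
        con.close()

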
1423### Types Handling
1424
1425class Type(frozenset):
1426    """Type class for a group of PostgreSQL data types.
1427
1428    PostgreSQL is object-oriented: types are dynamic.
1429    We must thus use type names as internal type codes.
1430    """
1431
1432    def __new__(cls, values):
1433        if isinstance(values, basestring):
1434            values = values.split()
1435        return super(Type, cls).__new__(cls, values)
1436
1437    def __eq__(self, other):
1438        if isinstance(other, basestring):
1439            if other.startswith('_'):
1440                other = other[1:]
1441            return other in self
1442        else:
1443            return super(Type, self).__eq__(other)
1444
1445    def __ne__(self, other):
1446        if isinstance(other, basestring):
1447            if other.startswith('_'):
1448                other = other[1:]
1449            return other not in self
1450        else:
1451            return super(Type, self).__ne__(other)
1452
1453
1454class ArrayType:
1455    """Type class for PostgreSQL array types."""
1456
1457    def __eq__(self, other):
1458        if isinstance(other, basestring):
1459            return other.startswith('_')
1460        else:
1461            return isinstance(other, ArrayType)
1462
1463    def __ne__(self, other):
1464        if isinstance(other, basestring):
1465            return not other.startswith('_')
1466        else:
1467            return not isinstance(other, ArrayType)
1468
1469
1470class RecordType:
1471    """Type class for PostgreSQL record types."""
1472
1473    def __eq__(self, other):
1474        if isinstance(other, TypeCode):
1475            return other.type == 'c'
1476        elif isinstance(other, basestring):
1477            return other == 'record'
1478        else:
1479            return isinstance(other, RecordType)
1480
1481    def __ne__(self, other):
1482        if isinstance(other, TypeCode):
1483            return other.type != 'c'
1484        elif isinstance(other, basestring):
1485            return other != 'record'
1486        else:
1487            return not isinstance(other, RecordType)
1488
1489
1490# Mandatory type objects defined by DB-API 2 specs:
1491
1492STRING = Type('char bpchar name text varchar')
1493BINARY = Type('bytea')
1494NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
1495DATETIME = Type('date time timetz timestamp timestamptz interval'
1496    ' abstime reltime')  # these are very old
1497ROWID = Type('oid')
1498
1499
1500# Additional type objects (more specific):
1501
1502BOOL = Type('bool')
1503SMALLINT = Type('int2')
1504INTEGER = Type('int2 int4 int8 serial')
1505LONG = Type('int8')
1506FLOAT = Type('float4 float8')
1507NUMERIC = Type('numeric')
1508MONEY = Type('money')
1509DATE = Type('date')
1510TIME = Type('time timetz')
1511TIMESTAMP = Type('timestamp timestamptz')
1512INTERVAL = Type('interval')
1513JSON = Type('json jsonb')
1514
1515# Type object for arrays (also equate to their base types):
1516
1517ARRAY = ArrayType()
1518
1519# Type object for records (encompassing all composite types):
1520
1521RECORD = RecordType()
1522
1523
1524# Mandatory type helpers defined by DB-API 2 specs:
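# Illustrative sketch (not part of the original module): the type objects
# above compare equal to the type_code entries in cursor.description, so
# result columns can be classified like this (assuming a cursor `cur` on
# which a query has already been executed):
def _example_numeric_columns(cur):
    """Return the names of all numeric columns of the current result set."""
    return [d.name for d in cur.description if d.type_code == NUMBER]

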
1525
1526def Date(year, month, day):
1527    """Construct an object holding a date value."""
1528    return date(year, month, day)
1529
1530
1531def Time(hour, minute=0, second=0, microsecond=0):
1532    """Construct an object holding a time value."""
1533    return time(hour, minute, second, microsecond)
1534
1535
1536def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
1537    """Construct an object holding a time stamp value."""
1538    return datetime(year, month, day, hour, minute, second, microsecond)
1539
1540
1541def DateFromTicks(ticks):
1542    """Construct an object holding a date value from the given ticks value."""
1543    return Date(*localtime(ticks)[:3])
1544
1545
1546def TimeFromTicks(ticks):
1547    """Construct an object holding a time value from the given ticks value."""
1548    return Time(*localtime(ticks)[3:6])
1549
1550
1551def TimestampFromTicks(ticks):
1552    """Construct an object holding a time stamp from the given ticks value."""
1553    return Timestamp(*localtime(ticks)[:6])
1554
1555
1556class Binary(bytes):
1557    """Construct an object capable of holding a binary (long) string value."""
1558
1559
1560# Additional type helpers for PyGreSQL:
1561
1562def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
1563    """Construct an object holding a time interval value."""
1564    return timedelta(days, hours=hours, minutes=minutes, seconds=seconds,
1565        microseconds=microseconds)
1566
1567class Bytea(bytes):
1568    """Construct an object capable of holding a bytea value."""
1569
1570
1571class Json:
1572    """Construct a wrapper for holding an object serializable to JSON."""
1573
1574    def __init__(self, obj, encode=None):
1575        self.obj = obj
1576        self.encode = encode or jsonencode
1577
1578    def __str__(self):
1579        obj = self.obj
1580        if isinstance(obj, basestring):
1581            return obj
1582        return self.encode(obj)
1583
1584    __pg_repr__ = __str__
1585
1586
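# Illustrative sketch (not part of the original module): wrapping a dict in
# Json() lets it be passed as a query parameter and serialized automatically
# (assuming an open cursor `cur` and a hypothetical table with a jsonb column):
def _example_insert_json(cur):
    """Insert a Python dict as a jsonb value."""
    cur.execute("INSERT INTO settings (data) VALUES (%s)",
        [Json({'theme': 'dark', 'pages': 10})])

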
1587class Literal:
1588    """Construct a wrapper for holding a literal SQL string."""
1589
1590    def __init__(self, sql):
1591        self.sql = sql
1592
1593    def __str__(self):
1594        return self.sql
1595
1596    __pg_repr__ = __str__
1597
1598
1599# If run as script, print some information:
1600
1601if __name__ == '__main__':
1602    print('PyGreSQL version', version)
1603    print('')
1604    print(__doc__)