source: trunk/pgdb.py @ 819

Last change on this file since 819 was 819, checked in by cito, 4 years ago

Add type object for the hstore type

  • Property svn:keywords set to Id
File size: 52.8 KB
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 819 2016-02-05 12:43:28Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host, database,
24    # user and password as individual keyword arguments. To pass
25    # a port, append it to the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary) if they are
32    # passed. The binding syntax is the same as the % operator
33    # for dictionaries, and no quoting is done.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size and null_ok are not implemented;
52    # precision and scale are only set for numeric columns.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
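
A minimal end-to-end example (the database name and credentials used
here are illustrative only):

    con = pgdb.connect(database='test', user='postgres')
    cur = con.cursor()
    cur.execute("select %(x)s + %(y)s as answer", {'x': 20, 'y': 22})
    print(cur.fetchone())  # Row(answer=42)
    cur.close()
    con.close()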
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from functools import partial
76from re import compile as regex
77from json import loads as jsondecode, dumps as jsonencode
78
79try:
80    long
81except NameError:  # Python >= 3.0
82    long = int
83
84try:
85    unicode
86except NameError:  # Python >= 3.0
87    unicode = str
88
89try:
90    basestring
91except NameError:  # Python >= 3.0
92    basestring = (str, bytes)
93
94from collections import Iterable
95
96
97### Module Constants
98
99# compliant with DB API 2.0
100apilevel = '2.0'
101
102# module may be shared, but not connections
103threadsafety = 1
104
105# this module uses the extended Python format codes
106paramstyle = 'pyformat'
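# Example of 'pyformat' parameter binding (table and values are illustrative):
#   cursor.execute("select * from fruits where name = %(name)s",
#       {'name': 'apple'})
#   cursor.execute("select * from fruits where name = %s", ('apple',))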
107
108# shortcut methods have been excluded from DB API 2 and
109# are not recommended by the DB SIG, but they can be handy
110shortcutmethods = 1
111
112
113### Internal Type Handling
114
115try:
116    from inspect import signature
117except ImportError:  # Python < 3.3
118    from inspect import getargspec
119
120    def get_args(func):
121        return getargspec(func).args
122else:
123
124    def get_args(func):
125        return list(signature(func).parameters)
126
127try:
128    if datetime.strptime('+0100', '%z') is None:
129        raise ValueError
130except ValueError:  # Python < 3.2
131    timezones = None
132else:
133    # time zones used in Postgres timestamptz output
134    timezones = dict(CET='+0100', EET='+0200', EST='-0500',
135        GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
136        UCT='+0000', UTC='+0000', WET='+0000')
137
138
139def decimal_type(decimal_type=None):
140    """Get or set global type to be used for decimal values.
141
142    Note that connections cache cast functions. To be sure a global change
143    is picked up by a running connection, call con.type_cache.reset_typecast().
144    """
145    global Decimal
146    if decimal_type is not None:
147        Decimal = decimal_type
148        set_typecast('numeric', decimal_type)
149    return Decimal
150
151
152def cast_bool(value):
153    """Cast boolean value in database format to bool."""
154    if value:
155        return value[0] in ('t', 'T')
156
157
158def cast_money(value):
159    """Cast money value in database format to Decimal."""
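    # Example inputs (assuming the default money output formatting):
    #   '$123.45'    ->  Decimal('123.45')
    #   '($123.45)'  ->  Decimal('-123.45')  (parentheses mark negative amounts)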
160    if value:
161        value = value.replace('(', '-')
162        return Decimal(''.join(c for c in value if c.isdigit() or c in '.-'))
163
164
165def cast_int2vector(value):
166    """Cast an int2vector value."""
167    return [int(v) for v in value.split()]
168
169
170def cast_date(value, connection):
171    """Cast a date value."""
172    # The output format depends on the server setting DateStyle.  The default
173    # setting ISO and the setting for German are actually unambiguous.  The
174    # order of days and months in the other two settings is however ambiguous,
175    # so at least here we need to consult the setting to properly parse values.
176    if value == '-infinity':
177        return date.min
178    if value == 'infinity':
179        return date.max
180    value = value.split()
181    if value[-1] == 'BC':
182        return date.min
183    value = value[0]
184    if len(value) > 10:
185        return date.max
186    fmt = connection.date_format()
187    return datetime.strptime(value, fmt).date()
188
189
190def cast_time(value):
191    """Cast a time value."""
192    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
193    return datetime.strptime(value, fmt).time()
194
195
196_re_timezone = regex('(.*)([+-].*)')
197
198
199def cast_timetz(value):
200    """Cast a timetz value."""
201    tz = _re_timezone.match(value)
202    if tz:
203        value, tz = tz.groups()
204    else:
205        tz = '+0000'
206    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
207    if timezones:
208        if tz.startswith(('+', '-')):
209            if len(tz) < 5:
210                tz += '00'
211            else:
212                tz = tz.replace(':', '')
213        elif tz in timezones:
214            tz = timezones[tz]
215        else:
216            tz = '+0000'
217        value += tz
218        fmt += '%z'
219    return datetime.strptime(value, fmt).timetz()
220
221
222def cast_timestamp(value, connection):
223    """Cast a timestamp value."""
224    if value == '-infinity':
225        return datetime.min
226    if value == 'infinity':
227        return datetime.max
228    value = value.split()
229    if value[-1] == 'BC':
230        return datetime.min
231    fmt = connection.date_format()
232    if fmt.endswith('-%Y') and len(value) > 2:
233        value = value[1:5]
234        if len(value[3]) > 4:
235            return datetime.max
236        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
237            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
238    else:
239        if len(value[0]) > 10:
240            return datetime.max
241        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
242    return datetime.strptime(' '.join(value), ' '.join(fmt))
243
244
245def cast_timestamptz(value, connection):
246    """Cast a timestamptz value."""
247    if value == '-infinity':
248        return datetime.min
249    if value == 'infinity':
250        return datetime.max
251    value = value.split()
252    if value[-1] == 'BC':
253        return datetime.min
254    fmt = connection.date_format()
255    if fmt.endswith('-%Y') and len(value) > 2:
256        value = value[1:]
257        if len(value[3]) > 4:
258            return datetime.max
259        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
260            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
261        value, tz = value[:-1], value[-1]
262    else:
263        if fmt.startswith('%Y-'):
264            tz = _re_timezone.match(value[1])
265            if tz:
266                value[1], tz = tz.groups()
267            else:
268                tz = '+0000'
269        else:
270            value, tz = value[:-1], value[-1]
271        if len(value[0]) > 10:
272            return datetime.max
273        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
274    if timezones:
275        if tz.startswith(('+', '-')):
276            if len(tz) < 5:
277                tz += '00'
278            else:
279                tz = tz.replace(':', '')
280        elif tz in timezones:
281            tz = timezones[tz]
282        else:
283            tz = '+0000'
284        value.append(tz)
285        fmt.append('%z')
286    return datetime.strptime(' '.join(value), ' '.join(fmt))
287
288_re_interval_sql_standard = regex(
289    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
290    '(?:([+-]?[0-9]+)(?!:) ?)?'
291    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
292
293_re_interval_postgres = regex(
294    '(?:([+-]?[0-9]+) ?years? ?)?'
295    '(?:([+-]?[0-9]+) ?mons? ?)?'
296    '(?:([+-]?[0-9]+) ?days? ?)?'
297    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
298
299_re_interval_postgres_verbose = regex(
300    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
301    '(?:([+-]?[0-9]+) ?mons? ?)?'
302    '(?:([+-]?[0-9]+) ?days? ?)?'
303    '(?:([+-]?[0-9]+) ?hours? ?)?'
304    '(?:([+-]?[0-9]+) ?mins? ?)?'
305    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')
306
307_re_interval_iso_8601 = regex(
308    'P(?:([+-]?[0-9]+)Y)?'
309    '(?:([+-]?[0-9]+)M)?'
310    '(?:([+-]?[0-9]+)D)?'
311    '(?:T(?:([+-]?[0-9]+)H)?'
312    '(?:([+-]?[0-9]+)M)?'
313    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')
314
315
316def cast_interval(value):
317    """Cast an interval value."""
318    # The output format depends on the server setting IntervalStyle, but it's
319    # not necessary to consult this setting to parse it.  It's faster to just
320    # check all possible formats, and there is no ambiguity here.
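    # Typical inputs for the different IntervalStyle settings (examples only):
    #   iso_8601:          'P1Y2M3DT4H5M6S'
    #   postgres_verbose:  '@ 1 year 2 mons 3 days 4 hours 5 mins 6 secs'
    #   postgres:          '1 year 2 mons 3 days 04:05:06'
    #   sql_standard:      '+1-2 +3 +4:05:06'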
321    m = _re_interval_iso_8601.match(value)
322    if m:
323        m = [d or '0' for d in m.groups()]
324        secs_ago = m.pop(5) == '-'
325        m = [int(d) for d in m]
326        years, mons, days, hours, mins, secs, usecs = m
327        if secs_ago:
328            secs = -secs
329            usecs = -usecs
330    else:
331        m = _re_interval_postgres_verbose.match(value)
332        if m:
333            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
334            secs_ago = m.pop(5) == '-'
335            m = [-int(d) for d in m] if ago else [int(d) for d in m]
336            years, mons, days, hours, mins, secs, usecs = m
337            if secs_ago:
338                secs = -secs
339                usecs = -usecs
340        else:
341            m = _re_interval_postgres.match(value)
342            if m and any(m.groups()):
343                m = [d or '0' for d in m.groups()]
344                hours_ago = m.pop(3) == '-'
345                m = [int(d) for d in m]
346                years, mons, days, hours, mins, secs, usecs = m
347                if hours_ago:
348                    hours = -hours
349                    mins = -mins
350                    secs = -secs
351                    usecs = -usecs
352            else:
353                m = _re_interval_sql_standard.match(value)
354                if m and any(m.groups()):
355                    m = [d or '0' for d in m.groups()]
356                    years_ago = m.pop(0) == '-'
357                    hours_ago = m.pop(3) == '-'
358                    m = [int(d) for d in m]
359                    years, mons, days, hours, mins, secs, usecs = m
360                    if years_ago:
361                        years = -years
362                        mons = -mons
363                    if hours_ago:
364                        hours = -hours
365                        mins = -mins
366                        secs = -secs
367                        usecs = -usecs
368                else:
369                    raise ValueError('Cannot parse interval: %s' % value)
370    days += 365 * years + 30 * mons
371    return timedelta(days=days, hours=hours, minutes=mins,
372        seconds=secs, microseconds=usecs)
373
374
375class Typecasts(dict):
376    """Dictionary mapping database types to typecast functions.
377
378    The cast functions get passed the string representation of a value in
379    the database which they need to convert to a Python object.  The
380    passed string will never be None since NULL values are already
381    handled before the cast function is called.
382    """
383
384    # the default cast functions
385    # (str functions are ignored but have been added for faster access)
386    defaults = {'char': str, 'bpchar': str, 'name': str,
387        'text': str, 'varchar': str,
388        'bool': cast_bool, 'bytea': unescape_bytea,
389        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
390        'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode,
391        'float4': float, 'float8': float,
392        'numeric': Decimal, 'money': cast_money,
393        'date': cast_date, 'interval': cast_interval,
394        'time': cast_time, 'timetz': cast_timetz,
395        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
396        'int2vector': cast_int2vector,
397        'anyarray': cast_array, 'record': cast_record}
398
399    connection = None  # will be set in local connection specific instances
400
401    def __missing__(self, typ):
402        """Create a cast function if it is not cached.
403
404        Note that this class never raises a KeyError,
405        but returns None when no special cast function exists.
406        """
407        if not isinstance(typ, str):
408            raise TypeError('Invalid type: %s' % typ)
409        cast = self.defaults.get(typ)
410        if cast:
411            # store default for faster access
412            cast = self._add_connection(cast)
413            self[typ] = cast
414        elif typ.startswith('_'):
415            # create array cast
416            base_cast = self[typ[1:]]
417            cast = self.create_array_cast(base_cast)
418            if base_cast:
419                # store only if base type exists
420                self[typ] = cast
421        return cast
422
423    @staticmethod
424    def _needs_connection(func):
425        """Check if a typecast function needs a connection argument."""
426        try:
427            args = get_args(func)
428        except (TypeError, ValueError):
429            return False
430        else:
431            return 'connection' in args[1:]
432
433    def _add_connection(self, cast):
434        """Add a connection argument to the typecast function if necessary."""
435        if not self.connection or not self._needs_connection(cast):
436            return cast
437        return partial(cast, connection=self.connection)
438
439    def get(self, typ, default=None):
440        """Get the typecast function for the given database type."""
441        return self[typ] or default
442
443    def set(self, typ, cast):
444        """Set a typecast function for the specified database type(s)."""
445        if isinstance(typ, basestring):
446            typ = [typ]
447        if cast is None:
448            for t in typ:
449                self.pop(t, None)
450                self.pop('_%s' % t, None)
451        else:
452            if not callable(cast):
453                raise TypeError("Cast parameter must be callable")
454            for t in typ:
455                self[t] = self._add_connection(cast)
456                self.pop('_%s' % t, None)
457
458    def reset(self, typ=None):
459        """Reset the typecasts for the specified type(s) to their defaults.
460
461        When no type is specified, all typecasts will be reset.
462        """
463        defaults = self.defaults
464        if typ is None:
465            self.clear()
466            self.update(defaults)
467        else:
468            if isinstance(typ, basestring):
469                typ = [typ]
470            for t in typ:
471                cast = defaults.get(t)
472                if cast:
473                    self[t] = self._add_connection(cast)
474                    t = '_%s' % t
475                    cast = defaults.get(t)
476                    if cast:
477                        self[t] = self._add_connection(cast)
478                    else:
479                        self.pop(t, None)
480                else:
481                    self.pop(t, None)
482                    self.pop('_%s' % t, None)
483
484    def create_array_cast(self, basecast):
485        """Create an array typecast for the given base cast."""
486        def cast(v):
487            return cast_array(v, basecast)
488        return cast
489
490    def create_record_cast(self, name, fields, casts):
491        """Create a named record typecast for the given fields and casts."""
492        record = namedtuple(name, fields)
493        def cast(v):
494            return record(*cast_record(v, casts))
495        return cast
496
497
498_typecasts = Typecasts()  # this is the global typecast dictionary
499
500
501def get_typecast(typ):
502    """Get the global typecast function for the given database type(s)."""
503    return _typecasts.get(typ)
504
505
506def set_typecast(typ, cast):
507    """Set a global typecast function for the given database type(s).
508
509    Note that connections cache cast functions. To be sure a global change
510    is picked up by a running connection, call con.type_cache.reset_typecast().
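
    For instance (an illustrative cast), uuid columns could be mapped to
    Python UUID objects:

        from uuid import UUID
        pgdb.set_typecast('uuid', UUID)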
511    """
512    _typecasts.set(typ, cast)
513
514
515def reset_typecast(typ=None):
516    """Reset the global typecasts for the given type(s) to their default.
517
518    When no type is specified, all typecasts will be reset.
519
520    Note that connections cache cast functions. To be sure a global change
521    is picked up by a running connection, call con.type_cache.reset_typecast().
522    """
523    _typecasts.reset(typ)
524
525
526class LocalTypecasts(Typecasts):
527    """Map typecasts, including local composite types, to cast functions."""
528
529    defaults = _typecasts
530
531    connection = None  # will be set in a connection specific instance
532
533    def __missing__(self, typ):
534        """Create a cast function if it is not cached."""
535        if typ.startswith('_'):
536            base_cast = self[typ[1:]]
537            cast = self.create_array_cast(base_cast)
538            if base_cast:
539                self[typ] = cast
540        else:
541            cast = self.defaults.get(typ)
542            if cast:
543                cast = self._add_connection(cast)
544                self[typ] = cast
545            else:
546                fields = self.get_fields(typ)
547                if fields:
548                    casts = [self[field.type] for field in fields]
549                    fields = [field.name for field in fields]
550                    cast = self.create_record_cast(typ, fields, casts)
551                    self[typ] = cast
552        return cast
553
554    def get_fields(self, typ):
555        """Return the fields for the given record type.
556
557        This method will be replaced with a method that looks up the fields
558        using the type cache of the connection.
559        """
560        return []
561
562
563class TypeCode(str):
564    """Class representing the type_code used by the DB-API 2.0.
565
566    TypeCode objects are strings equal to the PostgreSQL type name,
567    but carry some additional information.
568    """
569
570    @classmethod
571    def create(cls, oid, name, len, type, category, delim, relid):
572        """Create a type code for a PostgreSQL data type."""
573        self = cls(name)
574        self.oid = oid
575        self.len = len
576        self.type = type
577        self.category = category
578        self.delim = delim
579        self.relid = relid
580        return self
581
582FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
583
584
585class TypeCache(dict):
586    """Cache for database types.
587
588    This cache maps type OIDs and names to TypeCode strings containing
589    important information on the associated database type.
590    """
591
592    def __init__(self, cnx):
593        """Initialize type cache for connection."""
594        super(TypeCache, self).__init__()
595        self._escape_string = cnx.escape_string
596        self._src = cnx.source()
597        self._typecasts = LocalTypecasts()
598        self._typecasts.get_fields = self.get_fields
599        self._typecasts.connection = cnx
600
601    def __missing__(self, key):
602        """Get the type info from the database if it is not cached."""
603        if isinstance(key, int):
604            oid = key
605        else:
606            if '.' not in key and '"' not in key:
607                key = '"%s"' % key
608            oid = "'%s'::regtype" % self._escape_string(key)
609        try:
610            self._src.execute("SELECT oid, typname,"
611                 " typlen, typtype, typcategory, typdelim, typrelid"
612                " FROM pg_type WHERE oid=%s" % oid)
613        except ProgrammingError:
614            res = None
615        else:
616            res = self._src.fetch(1)
617        if not res:
618            raise KeyError('Type %s could not be found' % key)
619        res = res[0]
620        type_code = TypeCode.create(int(res[0]), res[1],
621            int(res[2]), res[3], res[4], res[5], int(res[6]))
622        self[type_code.oid] = self[str(type_code)] = type_code
623        return type_code
624
625    def get(self, key, default=None):
626        """Get the type even if it is not cached."""
627        try:
628            return self[key]
629        except KeyError:
630            return default
631
632    def get_fields(self, typ):
633        """Get the names and types of the fields of composite types."""
634        if not isinstance(typ, TypeCode):
635            typ = self.get(typ)
636            if not typ:
637                return None
638        if not typ.relid:
639            return None  # this type is not composite
640        self._src.execute("SELECT attname, atttypid"
641            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
642            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
643        return [FieldInfo(name, self.get(int(oid)))
644            for name, oid in self._src.fetch(-1)]
645
646    def get_typecast(self, typ):
647        """Get the typecast function for the given database type."""
648        return self._typecasts.get(typ)
649
650    def set_typecast(self, typ, cast):
651        """Set a typecast function for the specified database type(s)."""
652        self._typecasts.set(typ, cast)
653
654    def reset_typecast(self, typ=None):
655        """Reset the typecast function for the specified database type(s)."""
656        self._typecasts.reset(typ)
657
658    def typecast(self, value, typ):
659        """Cast the given value according to the given database type."""
660        if value is None:
661            # for NULL values, no typecast is necessary
662            return None
663        cast = self.get_typecast(typ)
664        if not cast or cast is str:
665            # no typecast is necessary
666            return value
667        return cast(value)
668
669
670class _quotedict(dict):
671    """Dictionary with auto quoting of its items.
672
673    The quote attribute must be set to the desired quote function.
674    """
675
676    def __getitem__(self, key):
677        return self.quote(super(_quotedict, self).__getitem__(key))
678
679
680### Error messages
681
682def _db_error(msg, cls=DatabaseError):
683    """Return DatabaseError with empty sqlstate attribute."""
684    error = cls(msg)
685    error.sqlstate = None
686    return error
687
688
689def _op_error(msg):
690    """Return OperationalError."""
691    return _db_error(msg, OperationalError)
692
693
694### Cursor Object
695
696class Cursor(object):
697    """Cursor object."""
698
699    def __init__(self, dbcnx):
700        """Create a cursor object for the database connection."""
701        self.connection = self._dbcnx = dbcnx
702        self._cnx = dbcnx._cnx
703        self.type_cache = dbcnx.type_cache
704        self._src = self._cnx.source()
705        # the official attribute for describing the result columns
706        self._description = None
707        if self.row_factory is Cursor.row_factory:
708            # the row factory needs to be determined dynamically
709            self.row_factory = None
710        else:
711            self.build_row_factory = None
712        self.rowcount = -1
713        self.arraysize = 1
714        self.lastrowid = None
715
716    def __iter__(self):
717        """Make cursor compatible to the iteration protocol."""
718        return self
719
720    def __enter__(self):
721        """Enter the runtime context for the cursor object."""
722        return self
723
724    def __exit__(self, et, ev, tb):
725        """Exit the runtime context for the cursor object."""
726        self.close()
727
728    def _quote(self, value):
729        """Quote value depending on its type."""
730        if value is None:
731            return 'NULL'
732        if isinstance(value, (datetime, date, time, timedelta, Hstore, Json)):
733            value = str(value)
734        if isinstance(value, basestring):
735            if isinstance(value, Binary):
736                value = self._cnx.escape_bytea(value)
737                if bytes is not str:  # Python >= 3.0
738                    value = value.decode('ascii')
739            else:
740                value = self._cnx.escape_string(value)
741            return "'%s'" % value
742        if isinstance(value, float):
743            if isinf(value):
744                return "'-Infinity'" if value < 0 else "'Infinity'"
745            if isnan(value):
746                return "'NaN'"
747            return value
748        if isinstance(value, (int, long, Decimal, Literal)):
749            return value
750        if isinstance(value, list):
751            # Quote value as an ARRAY constructor. This is better than using
752            # an array literal because it carries the information that this is
753            # an array and not a string.  One issue with this syntax is that
754            # you need to add an explicit typecast when passing empty arrays.
755            # The ARRAY keyword is actually only necessary at the top level.
756            q = self._quote
757            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
758        if isinstance(value, tuple):
759            # Quote as a ROW constructor.  This is better than using a record
760            # literal because it carries the information that this is a record
761            # and not a string.  We don't use the keyword ROW in order to make
762            # this usable with the IN syntax as well.  It is only necessary
763            # when the record has a single column, which is not really useful.
764            q = self._quote
765            return '(%s)' % ','.join(str(q(v)) for v in value)
766        try:
767            value = value.__pg_repr__()
768        except AttributeError:
769            raise InterfaceError(
770                'Do not know how to adapt type %s' % type(value))
771        if isinstance(value, (tuple, list)):
772            value = self._quote(value)
773        return value
774
775    def _quoteparams(self, string, parameters):
776        """Quote parameters.
777
778        This function works for both mappings and sequences.
779        """
780        if isinstance(parameters, dict):
781            parameters = _quotedict(parameters)
782            parameters.quote = self._quote
783        else:
784            parameters = tuple(map(self._quote, parameters))
785        return string % parameters
786
787    def _make_description(self, info):
788        """Make the description tuple for the given field info."""
789        name, typ, size, mod = info[1:]
790        type_code = self.type_cache[typ]
791        if mod > 0:
792            mod -= 4
793        if type_code == 'numeric':
794            precision, scale = mod >> 16, mod & 0xffff
795            size = precision
796        else:
797            if not size:
798                size = type_code.size
799            if size == -1:
800                size = mod
801            precision = scale = None
802        return CursorDescription(name, type_code,
803            None, size, precision, scale, None)
804
805    @property
806    def description(self):
807        """Read-only attribute describing the result columns."""
808        descr = self._description
809        if self._description is True:
810            make = self._make_description
811            descr = [make(info) for info in self._src.listinfo()]
812            self._description = descr
813        return descr
814
815    @property
816    def colnames(self):
817        """Unofficial convenience method for getting the column names."""
818        return [d[0] for d in self.description]
819
820    @property
821    def coltypes(self):
822        """Unofficial convenience method for getting the column types."""
823        return [d[1] for d in self.description]
824
825    def close(self):
826        """Close the cursor object."""
827        self._src.close()
828        self._description = None
829        self.rowcount = -1
830        self.lastrowid = None
831
832    def execute(self, operation, parameters=None):
833        """Prepare and execute a database operation (query or command)."""
834        # The parameters may also be specified as a list of tuples, e.g. to
835        # insert multiple rows in a single operation, but this kind of
836        # usage is deprecated.  We make several plausibility checks because
837        # tuples can also be passed with the meaning of ROW constructors.
838        if (parameters and isinstance(parameters, list)
839                and len(parameters) > 1
840                and all(isinstance(p, tuple) for p in parameters)
841                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
842            return self.executemany(operation, parameters)
843        else:
844            # not a list of tuples
845            return self.executemany(operation, [parameters])
846
847    def executemany(self, operation, seq_of_parameters):
848        """Prepare operation and execute it against a parameter sequence."""
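        # Usage sketch (table and column names are illustrative):
        #   cur.executemany("insert into fruits (name) values (%(name)s)",
        #       [{'name': 'apple'}, {'name': 'banana'}])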
849        if not seq_of_parameters:
850            # don't do anything without parameters
851            return
852        self._description = None
853        self.rowcount = -1
854        # first try to execute all queries
855        rowcount = 0
856        sql = "BEGIN"
857        try:
858            if not self._dbcnx._tnx:
859                try:
860                    self._cnx.source().execute(sql)
861                except DatabaseError:
862                    raise  # database provides error message
863                except Exception as err:
864                    raise _op_error("Can't start transaction")
865                self._dbcnx._tnx = True
866            for parameters in seq_of_parameters:
867                sql = operation
868                if parameters:
869                    sql = self._quoteparams(sql, parameters)
870                rows = self._src.execute(sql)
871                if rows:  # true if not DML
872                    rowcount += rows
873                else:
874                    self.rowcount = -1
875        except DatabaseError:
876            raise  # database provides error message
877        except Error as err:
878            raise _db_error(
879                "Error in '%s': '%s' " % (sql, err), InterfaceError)
880        except Exception as err:
881            raise _op_error("Internal error in '%s': %s" % (sql, err))
882        # then initialize result raw count and description
883        if self._src.resulttype == RESULT_DQL:
884            self._description = True  # fetch on demand
885            self.rowcount = self._src.ntuples
886            self.lastrowid = None
887            if self.build_row_factory:
888                self.row_factory = self.build_row_factory()
889        else:
890            self.rowcount = rowcount
891            self.lastrowid = self._src.oidstatus()
892        # return the cursor object, so you can write statements such as
893        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
894        return self
895
896    def fetchone(self):
897        """Fetch the next row of a query result set."""
898        res = self.fetchmany(1, False)
899        try:
900            return res[0]
901        except IndexError:
902            return None
903
904    def fetchall(self):
905        """Fetch all (remaining) rows of a query result."""
906        return self.fetchmany(-1, False)
907
908    def fetchmany(self, size=None, keep=False):
909        """Fetch the next set of rows of a query result.
910
911        The number of rows to fetch per call is specified by the
912        size parameter. If it is not given, the cursor's arraysize
913        determines the number of rows to be fetched. If you set
914        the keep parameter to true, this is kept as new arraysize.
915        """
916        if size is None:
917            size = self.arraysize
918        if keep:
919            self.arraysize = size
920        try:
921            result = self._src.fetch(size)
922        except DatabaseError:
923            raise
924        except Error as err:
925            raise _db_error(str(err))
926        typecast = self.type_cache.typecast
927        return [self.row_factory([typecast(value, typ)
928            for typ, value in zip(self.coltypes, row)]) for row in result]
929
930    def callproc(self, procname, parameters=None):
931        """Call a stored database procedure with the given name.
932
933        The sequence of parameters must contain one entry for each input
934        argument that the procedure expects. The result of the call is the
935        same as this input sequence; replacement of output and input/output
936        parameters in the return value is currently not supported.
937
938        The procedure may also provide a result set as output. These can be
939        requested through the standard fetch methods of the cursor.
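
        Example (calling a built-in function, for illustration only):

            cur.callproc('version')
            print(cur.fetchone()[0])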
940        """
941        n = parameters and len(parameters) or 0
942        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
943        self.execute(query, parameters)
944        return parameters
945
946    def copy_from(self, stream, table,
947            format=None, sep=None, null=None, size=None, columns=None):
948        """Copy data from an input stream to the specified table.
949
950        The input stream can be a file-like object with a read() method or
951        it can also be an iterable returning a row or multiple rows of input
952        on each iteration.
953
954        The format must be text, csv or binary. The sep option sets the
955        column separator (delimiter) used in the non binary formats.
956        The null option sets the textual representation of NULL in the input.
957
958        The size option sets the size of the buffer used when reading data
959        from file-like objects.
960
961        The copy operation can be restricted to a subset of columns. If no
962        columns are specified, all of them will be copied.
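
        A small sketch (file and table names are illustrative):

            with open('fruits.csv') as f:
                cur.copy_from(f, 'fruits', format='csv')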
963        """
964        binary_format = format == 'binary'
965        try:
966            read = stream.read
967        except AttributeError:
968            if size:
969                raise ValueError("Size must only be set for file-like objects")
970            if binary_format:
971                input_type = bytes
972                type_name = 'byte strings'
973            else:
974                input_type = basestring
975                type_name = 'strings'
976
977            if isinstance(stream, basestring):
978                if not isinstance(stream, input_type):
979                    raise ValueError("The input must be %s" % type_name)
980                if not binary_format:
981                    if isinstance(stream, str):
982                        if not stream.endswith('\n'):
983                            stream += '\n'
984                    else:
985                        if not stream.endswith(b'\n'):
986                            stream += b'\n'
987
988                def chunks():
989                    yield stream
990
991            elif isinstance(stream, Iterable):
992
993                def chunks():
994                    for chunk in stream:
995                        if not isinstance(chunk, input_type):
996                            raise ValueError(
997                                "Input stream must consist of %s" % type_name)
998                        if isinstance(chunk, str):
999                            if not chunk.endswith('\n'):
1000                                chunk += '\n'
1001                        else:
1002                            if not chunk.endswith(b'\n'):
1003                                chunk += b'\n'
1004                        yield chunk
1005
1006            else:
1007                raise TypeError("Need an input stream to copy from")
1008        else:
1009            if size is None:
1010                size = 8192
1011            elif not isinstance(size, int):
1012                raise TypeError("The size option must be an integer")
1013            if size > 0:
1014
1015                def chunks():
1016                    while True:
1017                        buffer = read(size)
1018                        yield buffer
1019                        if not buffer or len(buffer) < size:
1020                            break
1021
1022            else:
1023
1024                def chunks():
1025                    yield read()
1026
1027        if not table or not isinstance(table, basestring):
1028            raise TypeError("Need a table to copy to")
1029        if table.lower().startswith('select'):
1030            raise ValueError("Must specify a table, not a query")
1031        else:
1032            table = '"%s"' % (table,)
1033        operation = ['copy %s' % (table,)]
1034        options = []
1035        params = []
1036        if format is not None:
1037            if not isinstance(format, basestring):
1038                raise TypeError("The format option must be a string")
1039            if format not in ('text', 'csv', 'binary'):
1040                raise ValueError("Invalid format")
1041            options.append('format %s' % (format,))
1042        if sep is not None:
1043            if not isinstance(sep, basestring):
1044                raise TypeError("The sep option must be a string")
1045            if format == 'binary':
1046                raise ValueError(
1047                    "The sep option is not allowed with binary format")
1048            if len(sep) != 1:
1049                raise ValueError(
1050                    "The sep option must be a single one-byte character")
1051            options.append('delimiter %s')
1052            params.append(sep)
1053        if null is not None:
1054            if not isinstance(null, basestring):
1055                raise TypeError("The null option must be a string")
1056            options.append('null %s')
1057            params.append(null)
1058        if columns:
1059            if not isinstance(columns, basestring):
1060                columns = ','.join('"%s"' % (col,) for col in columns)
1061            operation.append('(%s)' % (columns,))
1062        operation.append("from stdin")
1063        if options:
1064            operation.append('(%s)' % ','.join(options))
1065        operation = ' '.join(operation)
1066
1067        putdata = self._src.putdata
1068        self.execute(operation, params)
1069
1070        try:
1071            for chunk in chunks():
1072                putdata(chunk)
1073        except BaseException as error:
1074            self.rowcount = -1
1075            # the following call will re-raise the error
1076            putdata(error)
1077        else:
1078            self.rowcount = putdata(None)
1079
1080        # return the cursor object, so you can chain operations
1081        return self
1082
1083    def copy_to(self, stream, table,
1084            format=None, sep=None, null=None, decode=None, columns=None):
1085        """Copy data from the specified table to an output stream.
1086
1087        The output stream can be a file-like object with a write() method or
1088        it can also be None, in which case the method will return a generator
1089        yielding a row on each iteration.
1090
1091        Output will be returned as byte strings unless you set decode to true.
1092
1093        Note that you can also use a select query instead of the table name.
1094
1095        The format must be text, csv or binary. The sep option sets the
1096        column separator (delimiter) used in the non binary formats.
1097        The null option sets the textual representation of NULL in the output.
1098
1099        The copy operation can be restricted to a subset of columns. If no
1100        columns are specified, all of them will be copied.
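
        A small sketch (the table name is illustrative); with stream=None
        the rows are returned by a generator:

            for row in cur.copy_to(None, 'fruits', decode=True):
                print(row, end='')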
1101        """
1102        binary_format = format == 'binary'
1103        if stream is not None:
1104            try:
1105                write = stream.write
1106            except AttributeError:
1107                raise TypeError("Need an output stream to copy to")
1108        if not table or not isinstance(table, basestring):
1109            raise TypeError("Need a table to copy from")
1110        if table.lower().startswith('select'):
1111            if columns:
1112                raise ValueError("Columns must be specified in the query")
1113            table = '(%s)' % (table,)
1114        else:
1115            table = '"%s"' % (table,)
1116        operation = ['copy %s' % (table,)]
1117        options = []
1118        params = []
1119        if format is not None:
1120            if not isinstance(format, basestring):
1121                raise TypeError("The format option must be a string")
1122            if format not in ('text', 'csv', 'binary'):
1123                raise ValueError("Invalid format")
1124            options.append('format %s' % (format,))
1125        if sep is not None:
1126            if not isinstance(sep, basestring):
1127                raise TypeError("The sep option must be a string")
1128            if binary_format:
1129                raise ValueError(
1130                    "The sep option is not allowed with binary format")
1131            if len(sep) != 1:
1132                raise ValueError(
1133                    "The sep option must be a single one-byte character")
1134            options.append('delimiter %s')
1135            params.append(sep)
1136        if null is not None:
1137            if not isinstance(null, basestring):
1138                raise TypeError("The null option must be a string")
1139            options.append('null %s')
1140            params.append(null)
1141        if decode is None:
1142            if format == 'binary':
1143                decode = False
1144            else:
1145                decode = str is unicode
1146        else:
1147            if not isinstance(decode, (int, bool)):
1148                raise TypeError("The decode option must be a boolean")
1149            if decode and binary_format:
1150                raise ValueError(
1151                    "The decode option is not allowed with binary format")
1152        if columns:
1153            if not isinstance(columns, basestring):
1154                columns = ','.join('"%s"' % (col,) for col in columns)
1155            operation.append('(%s)' % (columns,))
1156
1157        operation.append("to stdout")
1158        if options:
1159            operation.append('(%s)' % ','.join(options))
1160        operation = ' '.join(operation)
1161
1162        getdata = self._src.getdata
1163        self.execute(operation, params)
1164
1165        def copy():
1166            self.rowcount = 0
1167            while True:
1168                row = getdata(decode)
1169                if isinstance(row, int):
1170                    if self.rowcount != row:
1171                        self.rowcount = row
1172                    break
1173                self.rowcount += 1
1174                yield row
1175
1176        if stream is None:
1177            # no output stream given, return the generator
1178            return copy()
1179
1180        # write the rows to the file-like output stream
1181        for row in copy():
1182            write(row)
1183
1184        # return the cursor object, so you can chain operations
1185        return self
1186
1187    def __next__(self):
1188        """Return the next row (support for the iteration protocol)."""
1189        res = self.fetchone()
1190        if res is None:
1191            raise StopIteration
1192        return res
1193
1194    # Since Python 3.0 the iterator protocol uses __next__() instead of
1195    # next(); we keep next() only for backward compatibility of pgdb.
1196    next = __next__
1197
1198    @staticmethod
1199    def nextset():
1200        """Not supported."""
1201        raise NotSupportedError("The nextset() method is not supported")
1202
1203    @staticmethod
1204    def setinputsizes(sizes):
1205        """Not supported."""
1206        pass  # unsupported, but silently passed
1207
1208    @staticmethod
1209    def setoutputsize(size, column=0):
1210        """Not supported."""
1211        pass  # unsupported, but silently passed
1212
1213    @staticmethod
1214    def row_factory(row):
1215        """Process rows before they are returned.
1216
1217        You can overwrite this statically with a custom row factory, or
1218        you can build a row factory dynamically with build_row_factory().
1219
1220        For example, you can create a Cursor class that returns rows as
1221        Python dictionaries like this:
1222
1223            class DictCursor(pgdb.Cursor):
1224
1225                def row_factory(self, row):
1226                    return {desc[0]: value
1227                        for desc, value in zip(self.description, row)}
1228
1229            cur = DictCursor(con)  # get one DictCursor instance or
1230            con.cursor_type = DictCursor  # always use DictCursor instances
1231        """
1232        raise NotImplementedError
1233
1234    def build_row_factory(self):
1235        """Build a row factory based on the current description.
1236
1237        This implementation builds a row factory for creating named tuples.
1238        You can overwrite this method if you want to dynamically create
1239        different row factories whenever the column description changes.
1240        """
1241        colnames = self.colnames
1242        if colnames:
1243            try:
1244                try:
1245                    return namedtuple('Row', colnames, rename=True)._make
1246                except TypeError:  # Python 2.6 and 3.0 do not support rename
1247                    colnames = [v if v.isalnum() else 'column_%d' % n
1248                             for n, v in enumerate(colnames)]
1249                    return namedtuple('Row', colnames)._make
1250            except ValueError:  # there is still a problem with the field names
1251                colnames = ['column_%d' % n for n in range(len(colnames))]
1252                return namedtuple('Row', colnames)._make
1253
1254
1255CursorDescription = namedtuple('CursorDescription',
1256    ['name', 'type_code', 'display_size', 'internal_size',
1257     'precision', 'scale', 'null_ok'])
1258
1259
1260### Connection Objects
1261
1262class Connection(object):
1263    """Connection object."""
1264
1265    # expose the exceptions as attributes on the connection object
1266    Error = Error
1267    Warning = Warning
1268    InterfaceError = InterfaceError
1269    DatabaseError = DatabaseError
1270    InternalError = InternalError
1271    OperationalError = OperationalError
1272    ProgrammingError = ProgrammingError
1273    IntegrityError = IntegrityError
1274    DataError = DataError
1275    NotSupportedError = NotSupportedError
1276
1277    def __init__(self, cnx):
1278        """Create a database connection object."""
1279        self._cnx = cnx  # connection
1280        self._tnx = False  # transaction state
1281        self.type_cache = TypeCache(cnx)
1282        self.cursor_type = Cursor
1283        try:
1284            self._cnx.source()
1285        except Exception:
1286            raise _op_error("Invalid connection")
1287
1288    def __enter__(self):
1289        """Enter the runtime context for the connection object.
1290
1291        The runtime context can be used for running transactions.
1292        """
1293        return self
1294
1295    def __exit__(self, et, ev, tb):
1296        """Exit the runtime context for the connection object.
1297
1298        This does not close the connection, but it ends a transaction.
1299        """
1300        if et is None and ev is None and tb is None:
1301            self.commit()
1302        else:
1303            self.rollback()
1304
1305    def close(self):
1306        """Close the connection object."""
1307        if self._cnx:
1308            if self._tnx:
1309                try:
1310                    self.rollback()
1311                except DatabaseError:
1312                    pass
1313            self._cnx.close()
1314            self._cnx = None
1315        else:
1316            raise _op_error("Connection has been closed")
1317
1318    def commit(self):
1319        """Commit any pending transaction to the database."""
1320        if self._cnx:
1321            if self._tnx:
1322                self._tnx = False
1323                try:
1324                    self._cnx.source().execute("COMMIT")
1325                except DatabaseError:
1326                    raise
1327                except Exception:
1328                    raise _op_error("Can't commit")
1329        else:
1330            raise _op_error("Connection has been closed")
1331
1332    def rollback(self):
1333        """Roll back to the start of any pending transaction."""
1334        if self._cnx:
1335            if self._tnx:
1336                self._tnx = False
1337                try:
1338                    self._cnx.source().execute("ROLLBACK")
1339                except DatabaseError:
1340                    raise
1341                except Exception:
1342                    raise _op_error("Can't rollback")
1343        else:
1344            raise _op_error("Connection has been closed")
1345
1346    def cursor(self):
1347        """Return a new cursor object using the connection."""
1348        if self._cnx:
1349            try:
1350                return self.cursor_type(self)
1351            except Exception:
1352                raise _op_error("Invalid connection")
1353        else:
1354            raise _op_error("Connection has been closed")
1355
1356    if shortcutmethods:  # otherwise do not implement and document this
1357
1358        def execute(self, operation, params=None):
1359            """Shortcut method to run an operation on an implicit cursor."""
1360            cursor = self.cursor()
1361            cursor.execute(operation, params)
1362            return cursor
1363
1364        def executemany(self, operation, param_seq):
1365            """Shortcut method to run an operation against a sequence."""
1366            cursor = self.cursor()
1367            cursor.executemany(operation, param_seq)
1368            return cursor
1369
1370
1371### Module Interface
1372
1373_connect_ = connect
1374
1375def connect(dsn=None,
1376        user=None, password=None,
1377        host=None, database=None):
1378    """Connect to a database."""
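    # Usage sketch (host, database and credentials are illustrative):
    #   con = connect(host='localhost:5432', database='test',
    #       user='postgres', password='secret')
    #   con = connect('localhost:test:postgres:secret')  # same, as a DSN string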
1379    # first get params from DSN
1380    dbport = -1
1381    dbhost = ""
1382    dbbase = ""
1383    dbuser = ""
1384    dbpasswd = ""
1385    dbopt = ""
1386    try:
1387        params = dsn.split(":")
1388        dbhost = params[0]
1389        dbbase = params[1]
1390        dbuser = params[2]
1391        dbpasswd = params[3]
1392        dbopt = params[4]
1393    except (AttributeError, IndexError, TypeError):
1394        pass
1395
1396    # override if necessary
1397    if user is not None:
1398        dbuser = user
1399    if password is not None:
1400        dbpasswd = password
1401    if database is not None:
1402        dbbase = database
1403    if host is not None:
1404        try:
1405            params = host.split(":")
1406            dbhost = params[0]
1407            dbport = int(params[1])
1408        except (AttributeError, IndexError, TypeError, ValueError):
1409            pass
1410
1411    # empty host is localhost
1412    if dbhost == "":
1413        dbhost = None
1414    if dbuser == "":
1415        dbuser = None
1416
1417    # open the connection
1418    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
1419    return Connection(cnx)
1420
1421
1422### Types Handling
1423
1424class Type(frozenset):
1425    """Type class for a couple of PostgreSQL data types.
1426
1427    PostgreSQL is object-oriented: types are dynamic.
1428    We must thus use type names as internal type codes.
1429    """
1430
1431    def __new__(cls, values):
1432        if isinstance(values, basestring):
1433            values = values.split()
1434        return super(Type, cls).__new__(cls, values)
1435
1436    def __eq__(self, other):
1437        if isinstance(other, basestring):
1438            if other.startswith('_'):
1439                other = other[1:]
1440            return other in self
1441        else:
1442            return super(Type, self).__eq__(other)
1443
1444    def __ne__(self, other):
1445        if isinstance(other, basestring):
1446            if other.startswith('_'):
1447                other = other[1:]
1448            return other not in self
1449        else:
1450            return super(Type, self).__ne__(other)
1451
1452
1453class ArrayType:
1454    """Type class for PostgreSQL array types."""
1455
1456    def __eq__(self, other):
1457        if isinstance(other, basestring):
1458            return other.startswith('_')
1459        else:
1460            return isinstance(other, ArrayType)
1461
1462    def __ne__(self, other):
1463        if isinstance(other, basestring):
1464            return not other.startswith('_')
1465        else:
1466            return not isinstance(other, ArrayType)
1467
1468
1469class RecordType:
1470    """Type class for PostgreSQL record types."""
1471
1472    def __eq__(self, other):
1473        if isinstance(other, TypeCode):
1474            return other.type == 'c'
1475        elif isinstance(other, basestring):
1476            return other == 'record'
1477        else:
1478            return isinstance(other, RecordType)
1479
1480    def __ne__(self, other):
1481        if isinstance(other, TypeCode):
1482            return other.type != 'c'
1483        elif isinstance(other, basestring):
1484            return other != 'record'
1485        else:
1486            return not isinstance(other, RecordType)
1487
1488
1489# Mandatory type objects defined by DB-API 2 specs:
1490
1491STRING = Type('char bpchar name text varchar')
1492BINARY = Type('bytea')
1493NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
1494DATETIME = Type('date time timetz timestamp timestamptz interval'
1495    ' abstime reltime')  # these are very old
1496ROWID = Type('oid')
1497
1498
1499# Additional type objects (more specific):
1500
1501BOOL = Type('bool')
1502SMALLINT = Type('int2')
1503INTEGER = Type('int2 int4 int8 serial')
1504LONG = Type('int8')
1505FLOAT = Type('float4 float8')
1506NUMERIC = Type('numeric')
1507MONEY = Type('money')
1508DATE = Type('date')
1509TIME = Type('time timetz')
1510TIMESTAMP = Type('timestamp timestamptz')
1511INTERVAL = Type('interval')
1512HSTORE = Type('hstore')
1513JSON = Type('json jsonb')
1514
1515# Type object for arrays (also equate to their base types):
1516
1517ARRAY = ArrayType()
1518
1519# Type object for records (encompassing all composite types):
1520
1521RECORD = RecordType()
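
# These type objects compare equal to the type_code values found in
# cursor.description, e.g. (illustrative checks):
#   cursor.description[0].type_code == NUMBER  # True for an int4 column
#   cursor.description[0].type_code == ARRAY   # True for any array type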
1522
1523
1524# Mandatory type helpers defined by DB-API 2 specs:
1525
1526def Date(year, month, day):
1527    """Construct an object holding a date value."""
1528    return date(year, month, day)
1529
1530
1531def Time(hour, minute=0, second=0, microsecond=0):
1532    """Construct an object holding a time value."""
1533    return time(hour, minute, second, microsecond)
1534
1535
1536def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
1537    """Construct an object holding a time stamp value."""
1538    return datetime(year, month, day, hour, minute, second, microsecond)
1539
1540
1541def DateFromTicks(ticks):
1542    """Construct an object holding a date value from the given ticks value."""
1543    return Date(*localtime(ticks)[:3])
1544
1545
1546def TimeFromTicks(ticks):
1547    """Construct an object holding a time value from the given ticks value."""
1548    return Time(*localtime(ticks)[3:6])
1549
1550
1551def TimestampFromTicks(ticks):
1552    """Construct an object holding a time stamp from the given ticks value."""
1553    return Timestamp(*localtime(ticks)[:6])
1554
1555
1556class Binary(bytes):
1557    """Construct an object capable of holding a binary (long) string value."""
1558
1559
1560# Additional type helpers for PyGreSQL:
1561
1562def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0):
1563    """Construct an object holding a time inverval value."""
1564    return timedelta(days, hours=hours, minutes=minutes, seconds=seconds,
1565        microseconds=microseconds)
1566
1567
1568class Hstore(dict):
1569    """Wrapper class for marking hstore values."""
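    # Example (illustrative): passed as a query parameter,
    # Hstore({'k': 'v', 'n': None}) is rendered as the hstore literal
    # 'k=>v,n=>NULL' (the key order may vary).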
1570
1571    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
1572
1573    @classmethod
1574    def _quote(cls, s):
1575        if s is None:
1576            return 'NULL'
1577        if not s:
1578            return '""'
1579        s = s.replace('"', '\\"')
1580        if cls._re_quote.search(s):
1581            s = '"%s"' % s
1582        return s
1583
1584    def __str__(self):
1585        q = self._quote
1586        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
1587
1588
1589class Json:
1590    """Construct a wrapper for holding an object serializable to JSON."""
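    # Example (illustrative): Json({'answer': 42}) is passed to the
    # server as the quoted JSON string '{"answer": 42}'.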
1591
1592    def __init__(self, obj, encode=None):
1593        self.obj = obj
1594        self.encode = encode or jsonencode
1595
1596    def __str__(self):
1597        obj = self.obj
1598        if isinstance(obj, basestring):
1599            return obj
1600        return self.encode(obj)
1601
1602
1603class Literal:
1604    """Construct a wrapper for holding a literal SQL string."""
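    # Example (illustrative): Literal('current_timestamp') is embedded in
    # the query unquoted, whereas an ordinary string parameter would be
    # quoted and escaped.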
1605
1606    def __init__(self, sql):
1607        self.sql = sql
1608
1609    def __str__(self):
1610        return self.sql
1611
1612    __pg_repr__ = __str__
1613
1614# If run as script, print some information:
1615
1616if __name__ == '__main__':
1617    print('PyGreSQL version', version)
1618    print('')
1619    print(__doc__)