source: trunk/pgdb.py @ 798

Last change on this file since 798 was 798, checked in by cito, 3 years ago

Port type cache and typecasting from pgdb to pg

So far, typecasting in the classic module had only been done by
the C extension module and could not be extended with typecast
functions written in Python. This has now been made extensible by adding
a cast hook to the C extension module, which has been hooked up to
a new type cache object that holds information on the types and the
associated typecast functions. All of this now works very similarly to the
pgdb module, except that the basic types are still handled by
the C extension module and the Python typecast functions are only
called via the hook for types which are not supported internally.

Also added tests and a chapter on the type cache in the documentation,
and cleaned up the error messages in the C extension module.

  • Property svn:keywords set to Id
File size: 43.3 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 798 2016-01-30 19:55:18Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary) if they are
32    # passed. The binding syntax is the same as the % operator
33    # for dictionaries, and no quoting is done.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from json import loads as jsondecode, dumps as jsonencode
76
# Python 2 provides long, unicode and basestring as builtins, Python 3 has
# none of them; bind equivalents so the module can use these names anywhere.
try:
    long, unicode, basestring
except NameError:  # Python >= 3.0
    long, unicode, basestring = int, str, (str, bytes)
91
# The abstract base classes live in collections.abc since Python 3.3;
# the aliases in collections itself were removed in Python 3.10.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args):
            """Placeholder raising NotSupportedError if no OrderedDict exists."""
            raise NotSupportedError('OrderedDict is not supported')
101
102
### Module Constants

# this module is compliant with the DB-API 2.0
apilevel = '2.0'

# threads may share the module, but not connections
threadsafety = 1

# parameters are passed with extended Python format codes, e.g. %(name)s
paramstyle = 'pyformat'

# Shortcut methods are not part of DB-API 2 and not recommended by the
# DB SIG, but they can be handy, so this module provides them.
shortcutmethods = 1
117
118
119### Internal Type Handling
120
def decimal_type(decimal_type=None):
    """Return the global type used for decimal values, optionally setting it.

    If decimal_type is given, it replaces the global decimal type and is
    also registered as the global typecast function for 'numeric' values.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    global Decimal
    if decimal_type is None:
        # getter: nothing to change
        return Decimal
    Decimal = decimal_type
    set_typecast('numeric', decimal_type)
    return Decimal
132
133
def cast_bool(value):
    """Convert a boolean in database text format to a Python bool.

    Only the first character is inspected ('t' or 'T' means true).
    An empty value yields None.
    """
    if not value:
        return None
    return value[0] in ('t', 'T')
138
139
def cast_money(value):
    """Convert a money value in database text format to a Decimal.

    Opening parentheses (accounting notation for negative amounts) become
    minus signs, and all characters other than digits, '.' and '-' (such as
    currency symbols and thousands separators) are dropped.
    An empty value yields None.
    """
    if not value:
        return None
    value = value.replace('(', '-')
    kept = [char for char in value if char.isdigit() or char in '.-']
    return Decimal(''.join(kept))
145
146
def cast_int2vector(value):
    """Convert an int2vector value (whitespace separated integers) to a list."""
    return list(map(int, value.split()))
150
151
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values have already been
    handled before the cast function is called.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int,
        'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
        'oid': long, 'oid8': long,
        'float4': float, 'float8': float,
        'numeric': Decimal, 'money': cast_money,
        'int2vector': cast_int2vector,
        'anyarray': cast_array, 'record': cast_record}

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.

        For array types (names starting with an underscore), an array cast
        is built from the cast of the base type; it is only cached if the
        base type has a cast function.

        Raises TypeError if typ is not a string.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            self[typ] = cast
        elif typ.startswith('_'):
            # create array cast
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                # store only if base type exists
                self[typ] = cast
        return cast

    def get(self, typ, default=None):
        """Get the typecast function for the given database type.

        Returns default if there is no special cast function for the type.
        """
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        typ may be a single type name or an iterable of type names.
        Passing None as cast removes the cached casts for these types.
        In both cases the cached array cast of each type ('_' + name) is
        discarded, so that it will be rebuilt from the new base cast.

        Raises TypeError if cast is neither None nor callable.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = cast
                # drop the array cast that was built from the old base cast
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.
        """
        defaults = self.defaults
        if typ is None:
            self.clear()
            self.update(defaults)
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                self.set(t, defaults.get(t))

    def create_array_cast(self, cast):
        """Create an array typecast for the given base cast."""
        return lambda v: cast_array(v, cast)

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        record = namedtuple(name, fields)
        return lambda v: record(*cast_record(v, casts))
237
238
# The global typecast dictionary; LocalTypecasts uses it as its defaults.
_typecasts = Typecasts()


def get_typecast(typ):
    """Return the global typecast function for the given database type.

    Returns None if there is no special cast function for the type.
    """
    typecasts = _typecasts
    return typecasts.get(typ)
245
246
def set_typecast(typ, cast):
    """Register cast as global typecast function for the given type(s).

    typ may be a single type name or a list of names.  Passing None as
    cast removes the special cast function for these types.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    typecasts = _typecasts
    typecasts.set(typ, cast)
254
255
def reset_typecast(typ=None):
    """Restore the default global typecasts for the given type(s).

    typ may be a single type name or a list of names.  When no type is
    specified, all typecasts will be reset.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    typecasts = _typecasts
    typecasts.reset(typ)
265
266
class LocalTypecasts(Typecasts):
    """Map typecasts, including local composite types, to cast functions.

    Missing entries are taken from the global typecasts.  For types that
    have no cast there, a named record cast is created if get_fields()
    returns field information for the type.
    """

    defaults = _typecasts

    def __missing__(self, typ):
        """Create a cast function if it is not cached."""
        if typ.startswith('_'):
            # array type: derive the cast from the element type
            element_cast = self[typ[1:]]
            cast = self.create_array_cast(element_cast)
            if element_cast:
                self[typ] = cast
            return cast
        cast = self.defaults.get(typ)
        if cast:
            self[typ] = cast
            return cast
        fields = self.get_fields(typ)
        if fields:
            # composite type: cast each field with the cast of its type
            field_casts = [self[field.type] for field in fields]
            field_names = [field.name for field in fields]
            cast = self.create_record_cast(typ, field_names, field_casts)
            self[typ] = cast
        return cast

    def get_fields(self, typ):
        """Return the fields for the given record type.

        This method will be replaced with a method that looks up the fields
        using the type cache of the connection.
        """
        return []
299
300
class TypeCode(str):
    """Class representing the type_code used by the DB-API 2.0.

    TypeCode objects are strings equal to the PostgreSQL type name,
    but carry some additional information as attributes.
    """

    @classmethod
    def create(cls, oid, name, len, type, category, delim, relid):
        """Create a type code for a PostgreSQL data type.

        The result compares equal to name and carries the remaining
        arguments as attributes with the same names.
        """
        type_code = cls(name)
        type_code.oid, type_code.len, type_code.type = oid, len, type
        type_code.category, type_code.delim = category, delim
        type_code.relid = relid
        return type_code


# name and type code of a single field of a composite type
FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
321
322
class TypeCache(dict):
    """Cache for database types.

    This cache maps type OIDs and names to TypeCode strings containing
    important information on the associated database type.  Types that
    are not cached yet are looked up in the pg_type system catalog.
    The cache also holds the typecast functions local to the connection.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()
        typecasts = LocalTypecasts()
        # let the local typecasts look up record fields through this cache
        typecasts.get_fields = self.get_fields
        self._typecasts = typecasts

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The key can be a type OID (int) or a type name.  The found type is
        cached under both its OID and its name.

        Raises KeyError if the type cannot be found.
        """
        if isinstance(key, int):
            oid = key
        else:
            if '.' not in key and '"' not in key:
                # quote plain names so that they are taken verbatim
                key = '"%s"' % key
            oid = "'%s'::regtype" % self._escape_string(key)
        query = ("SELECT oid, typname,"
                 " typlen, typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s" % oid)
        try:
            self._src.execute(query)
        except ProgrammingError:
            rows = None
        else:
            rows = self._src.fetch(1)
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        type_oid, name, length, typtype, category, delim, relid = rows[0]
        type_code = TypeCode.create(int(type_oid), name,
            int(length), typtype, category, delim, int(relid))
        self[type_code.oid] = self[str(type_code)] = type_code
        return type_code

    def get(self, key, default=None):
        """Get the type even if it is not cached; default if it is unknown."""
        try:
            type_code = self[key]
        except KeyError:
            type_code = default
        return type_code

    def get_fields(self, typ):
        """Get the names and types of the fields of composite types.

        Returns a list of FieldInfo tuples, or None if the type is unknown
        or is not a composite type.
        """
        if not isinstance(typ, TypeCode):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None  # this type is not composite
        self._src.execute("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
        rows = self._src.fetch(-1)
        return [FieldInfo(name, self.get(int(oid))) for name, oid in rows]

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        NULL values (None) and values whose type has no special cast
        function (or only str) are returned unchanged.
        """
        if value is not None:
            cast = self.get_typecast(typ)
            if cast and cast is not str:
                value = cast(value)
        return value
405
406
407class _quotedict(dict):
408    """Dictionary with auto quoting of its items.
409
410    The quote attribute must be set to the desired quote function.
411    """
412
413    def __getitem__(self, key):
414        return self.quote(super(_quotedict, self).__getitem__(key))
415
416
417### Error messages
418
def _db_error(msg, cls=DatabaseError):
    """Create an error of class cls (DatabaseError by default) for msg.

    The sqlstate attribute of the returned error is set to None.
    """
    err = cls(msg)
    setattr(err, 'sqlstate', None)
    return err
424
425
def _op_error(msg):
    """Create an OperationalError for msg with an empty sqlstate."""
    return _db_error(msg, cls=OperationalError)
429
430
431### Cursor Object
432
433class Cursor(object):
434    """Cursor object."""
435
436    def __init__(self, dbcnx):
437        """Create a cursor object for the database connection."""
438        self.connection = self._dbcnx = dbcnx
439        self._cnx = dbcnx._cnx
440        self.type_cache = dbcnx.type_cache
441        self._src = self._cnx.source()
442        # the official attribute for describing the result columns
443        self._description = None
444        if self.row_factory is Cursor.row_factory:
445            # the row factory needs to be determined dynamically
446            self.row_factory = None
447        else:
448            self.build_row_factory = None
449        self.rowcount = -1
450        self.arraysize = 1
451        self.lastrowid = None
452
453    def __iter__(self):
454        """Make cursor compatible to the iteration protocol."""
455        return self
456
457    def __enter__(self):
458        """Enter the runtime context for the cursor object."""
459        return self
460
461    def __exit__(self, et, ev, tb):
462        """Exit the runtime context for the cursor object."""
463        self.close()
464
465    def _quote(self, value):
466        """Quote value depending on its type."""
467        if value is None:
468            return 'NULL'
469        if isinstance(value, (datetime, date, time, timedelta, Json)):
470            value = str(value)
471        if isinstance(value, basestring):
472            if isinstance(value, Binary):
473                value = self._cnx.escape_bytea(value)
474                if bytes is not str:  # Python >= 3.0
475                    value = value.decode('ascii')
476            else:
477                value = self._cnx.escape_string(value)
478            return "'%s'" % value
479        if isinstance(value, float):
480            if isinf(value):
481                return "'-Infinity'" if value < 0 else "'Infinity'"
482            if isnan(value):
483                return "'NaN'"
484            return value
485        if isinstance(value, (int, long, Decimal)):
486            return value
487        if isinstance(value, list):
488            # Quote value as an ARRAY constructor. This is better than using
489            # an array literal because it carries the information that this is
490            # an array and not a string.  One issue with this syntax is that
491            # you need to add an explicit typecast when passing empty arrays.
492            # The ARRAY keyword is actually only necessary at the top level.
493            q = self._quote
494            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
495        if isinstance(value, tuple):
496            # Quote as a ROW constructor.  This is better than using a record
497            # literal because it carries the information that this is a record
498            # and not a string.  We don't use the keyword ROW in order to make
499            # this usable with the IN synntax as well.  It is only necessary
500            # when the records has a single column which is not really useful.
501            q = self._quote
502            return '(%s)' % ','.join(str(q(v)) for v in value)
503        try:
504            value = value.__pg_repr__()
505            if isinstance(value, (tuple, list)):
506                value = self._quote(value)
507            return value
508        except AttributeError:
509            raise InterfaceError(
510                'Do not know how to adapt type %s' % type(value))
511
512    def _quoteparams(self, string, parameters):
513        """Quote parameters.
514
515        This function works for both mappings and sequences.
516        """
517        if isinstance(parameters, dict):
518            parameters = _quotedict(parameters)
519            parameters.quote = self._quote
520        else:
521            parameters = tuple(map(self._quote, parameters))
522        return string % parameters
523
524    def _make_description(self, info):
525        """Make the description tuple for the given field info."""
526        name, typ, size, mod = info[1:]
527        type_code = self.type_cache[typ]
528        if mod > 0:
529            mod -= 4
530        if type_code == 'numeric':
531            precision, scale = mod >> 16, mod & 0xffff
532            size = precision
533        else:
534            if not size:
535                size = type_info.size
536            if size == -1:
537                size = mod
538            precision = scale = None
539        return CursorDescription(name, type_code,
540            None, size, precision, scale, None)
541
542    @property
543    def description(self):
544        """Read-only attribute describing the result columns."""
545        descr = self._description
546        if self._description is True:
547            make = self._make_description
548            descr = [make(info) for info in self._src.listinfo()]
549            self._description = descr
550        return descr
551
552    @property
553    def colnames(self):
554        """Unofficial convenience method for getting the column names."""
555        return [d[0] for d in self.description]
556
557    @property
558    def coltypes(self):
559        """Unofficial convenience method for getting the column types."""
560        return [d[1] for d in self.description]
561
562    def close(self):
563        """Close the cursor object."""
564        self._src.close()
565        self._description = None
566        self.rowcount = -1
567        self.lastrowid = None
568
569    def execute(self, operation, parameters=None):
570        """Prepare and execute a database operation (query or command)."""
571        # The parameters may also be specified as list of tuples to e.g.
572        # insert multiple rows in a single operation, but this kind of
573        # usage is deprecated.  We make several plausibility checks because
574        # tuples can also be passed with the meaning of ROW constructors.
575        if (parameters and isinstance(parameters, list)
576                and len(parameters) > 1
577                and all(isinstance(p, tuple) for p in parameters)
578                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
579            return self.executemany(operation, parameters)
580        else:
581            # not a list of tuples
582            return self.executemany(operation, [parameters])
583
584    def executemany(self, operation, seq_of_parameters):
585        """Prepare operation and execute it against a parameter sequence."""
586        if not seq_of_parameters:
587            # don't do anything without parameters
588            return
589        self._description = None
590        self.rowcount = -1
591        # first try to execute all queries
592        rowcount = 0
593        sql = "BEGIN"
594        try:
595            if not self._dbcnx._tnx:
596                try:
597                    self._cnx.source().execute(sql)
598                except DatabaseError:
599                    raise  # database provides error message
600                except Exception as err:
601                    raise _op_error("Can't start transaction")
602                self._dbcnx._tnx = True
603            for parameters in seq_of_parameters:
604                sql = operation
605                if parameters:
606                    sql = self._quoteparams(sql, parameters)
607                rows = self._src.execute(sql)
608                if rows:  # true if not DML
609                    rowcount += rows
610                else:
611                    self.rowcount = -1
612        except DatabaseError:
613            raise  # database provides error message
614        except Error as err:
615            raise _db_error(
616                "Error in '%s': '%s' " % (sql, err), InterfaceError)
617        except Exception as err:
618            raise _op_error("Internal error in '%s': %s" % (sql, err))
619        # then initialize result raw count and description
620        if self._src.resulttype == RESULT_DQL:
621            self._description = True  # fetch on demand
622            self.rowcount = self._src.ntuples
623            self.lastrowid = None
624            if self.build_row_factory:
625                self.row_factory = self.build_row_factory()
626        else:
627            self.rowcount = rowcount
628            self.lastrowid = self._src.oidstatus()
629        # return the cursor object, so you can write statements such as
630        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
631        return self
632
633    def fetchone(self):
634        """Fetch the next row of a query result set."""
635        res = self.fetchmany(1, False)
636        try:
637            return res[0]
638        except IndexError:
639            return None
640
641    def fetchall(self):
642        """Fetch all (remaining) rows of a query result."""
643        return self.fetchmany(-1, False)
644
645    def fetchmany(self, size=None, keep=False):
646        """Fetch the next set of rows of a query result.
647
648        The number of rows to fetch per call is specified by the
649        size parameter. If it is not given, the cursor's arraysize
650        determines the number of rows to be fetched. If you set
651        the keep parameter to true, this is kept as new arraysize.
652        """
653        if size is None:
654            size = self.arraysize
655        if keep:
656            self.arraysize = size
657        try:
658            result = self._src.fetch(size)
659        except DatabaseError:
660            raise
661        except Error as err:
662            raise _db_error(str(err))
663        typecast = self.type_cache.typecast
664        return [self.row_factory([typecast(value, typ)
665            for typ, value in zip(self.coltypes, row)]) for row in result]
666
667    def callproc(self, procname, parameters=None):
668        """Call a stored database procedure with the given name.
669
670        The sequence of parameters must contain one entry for each input
671        argument that the procedure expects. The result of the call is the
672        same as this input sequence; replacement of output and input/output
673        parameters in the return value is currently not supported.
674
675        The procedure may also provide a result set as output. These can be
676        requested through the standard fetch methods of the cursor.
677        """
678        n = parameters and len(parameters) or 0
679        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
680        self.execute(query, parameters)
681        return parameters
682
683    def copy_from(self, stream, table,
684            format=None, sep=None, null=None, size=None, columns=None):
685        """Copy data from an input stream to the specified table.
686
687        The input stream can be a file-like object with a read() method or
688        it can also be an iterable returning a row or multiple rows of input
689        on each iteration.
690
691        The format must be text, csv or binary. The sep option sets the
692        column separator (delimiter) used in the non binary formats.
693        The null option sets the textual representation of NULL in the input.
694
695        The size option sets the size of the buffer used when reading data
696        from file-like objects.
697
698        The copy operation can be restricted to a subset of columns. If no
699        columns are specified, all of them will be copied.
700        """
701        binary_format = format == 'binary'
702        try:
703            read = stream.read
704        except AttributeError:
705            if size:
706                raise ValueError("Size must only be set for file-like objects")
707            if binary_format:
708                input_type = bytes
709                type_name = 'byte strings'
710            else:
711                input_type = basestring
712                type_name = 'strings'
713
714            if isinstance(stream, basestring):
715                if not isinstance(stream, input_type):
716                    raise ValueError("The input must be %s" % type_name)
717                if not binary_format:
718                    if isinstance(stream, str):
719                        if not stream.endswith('\n'):
720                            stream += '\n'
721                    else:
722                        if not stream.endswith(b'\n'):
723                            stream += b'\n'
724
725                def chunks():
726                    yield stream
727
728            elif isinstance(stream, Iterable):
729
730                def chunks():
731                    for chunk in stream:
732                        if not isinstance(chunk, input_type):
733                            raise ValueError(
734                                "Input stream must consist of %s" % type_name)
735                        if isinstance(chunk, str):
736                            if not chunk.endswith('\n'):
737                                chunk += '\n'
738                        else:
739                            if not chunk.endswith(b'\n'):
740                                chunk += b'\n'
741                        yield chunk
742
743            else:
744                raise TypeError("Need an input stream to copy from")
745        else:
746            if size is None:
747                size = 8192
748            elif not isinstance(size, int):
749                raise TypeError("The size option must be an integer")
750            if size > 0:
751
752                def chunks():
753                    while True:
754                        buffer = read(size)
755                        yield buffer
756                        if not buffer or len(buffer) < size:
757                            break
758
759            else:
760
761                def chunks():
762                    yield read()
763
764        if not table or not isinstance(table, basestring):
765            raise TypeError("Need a table to copy to")
766        if table.lower().startswith('select'):
767                raise ValueError("Must specify a table, not a query")
768        else:
769            table = '"%s"' % (table,)
770        operation = ['copy %s' % (table,)]
771        options = []
772        params = []
773        if format is not None:
774            if not isinstance(format, basestring):
775                raise TypeError("The frmat option must be be a string")
776            if format not in ('text', 'csv', 'binary'):
777                raise ValueError("Invalid format")
778            options.append('format %s' % (format,))
779        if sep is not None:
780            if not isinstance(sep, basestring):
781                raise TypeError("The sep option must be a string")
782            if format == 'binary':
783                raise ValueError(
784                    "The sep option is not allowed with binary format")
785            if len(sep) != 1:
786                raise ValueError(
787                    "The sep option must be a single one-byte character")
788            options.append('delimiter %s')
789            params.append(sep)
790        if null is not None:
791            if not isinstance(null, basestring):
792                raise TypeError("The null option must be a string")
793            options.append('null %s')
794            params.append(null)
795        if columns:
796            if not isinstance(columns, basestring):
797                columns = ','.join('"%s"' % (col,) for col in columns)
798            operation.append('(%s)' % (columns,))
799        operation.append("from stdin")
800        if options:
801            operation.append('(%s)' % ','.join(options))
802        operation = ' '.join(operation)
803
804        putdata = self._src.putdata
805        self.execute(operation, params)
806
807        try:
808            for chunk in chunks():
809                putdata(chunk)
810        except BaseException as error:
811            self.rowcount = -1
812            # the following call will re-raise the error
813            putdata(error)
814        else:
815            self.rowcount = putdata(None)
816
817        # return the cursor object, so you can chain operations
818        return self
819
820    def copy_to(self, stream, table,
821            format=None, sep=None, null=None, decode=None, columns=None):
822        """Copy data from the specified table to an output stream.
823
824        The output stream can be a file-like object with a write() method or
825        it can also be None, in which case the method will return a generator
826        yielding a row on each iteration.
827
828        Output will be returned as byte strings unless you set decode to true.
829
830        Note that you can also use a select query instead of the table name.
831
832        The format must be text, csv or binary. The sep option sets the
833        column separator (delimiter) used in the non binary formats.
834        The null option sets the textual representation of NULL in the output.
835
836        The copy operation can be restricted to a subset of columns. If no
837        columns are specified, all of them will be copied.
838        """
839        binary_format = format == 'binary'
840        if stream is not None:
841            try:
842                write = stream.write
843            except AttributeError:
844                raise TypeError("Need an output stream to copy to")
845        if not table or not isinstance(table, basestring):
846            raise TypeError("Need a table to copy to")
847        if table.lower().startswith('select'):
848            if columns:
849                raise ValueError("Columns must be specified in the query")
850            table = '(%s)' % (table,)
851        else:
852            table = '"%s"' % (table,)
853        operation = ['copy %s' % (table,)]
854        options = []
855        params = []
856        if format is not None:
857            if not isinstance(format, basestring):
858                raise TypeError("The format option must be a string")
859            if format not in ('text', 'csv', 'binary'):
860                raise ValueError("Invalid format")
861            options.append('format %s' % (format,))
862        if sep is not None:
863            if not isinstance(sep, basestring):
864                raise TypeError("The sep option must be a string")
865            if binary_format:
866                raise ValueError(
867                    "The sep option is not allowed with binary format")
868            if len(sep) != 1:
869                raise ValueError(
870                    "The sep option must be a single one-byte character")
871            options.append('delimiter %s')
872            params.append(sep)
873        if null is not None:
874            if not isinstance(null, basestring):
875                raise TypeError("The null option must be a string")
876            options.append('null %s')
877            params.append(null)
878        if decode is None:
879            if format == 'binary':
880                decode = False
881            else:
882                decode = str is unicode
883        else:
884            if not isinstance(decode, (int, bool)):
885                raise TypeError("The decode option must be a boolean")
886            if decode and binary_format:
887                raise ValueError(
888                    "The decode option is not allowed with binary format")
889        if columns:
890            if not isinstance(columns, basestring):
891                columns = ','.join('"%s"' % (col,) for col in columns)
892            operation.append('(%s)' % (columns,))
893
894        operation.append("to stdout")
895        if options:
896            operation.append('(%s)' % ','.join(options))
897        operation = ' '.join(operation)
898
899        getdata = self._src.getdata
900        self.execute(operation, params)
901
902        def copy():
903            self.rowcount = 0
904            while True:
905                row = getdata(decode)
906                if isinstance(row, int):
907                    if self.rowcount != row:
908                        self.rowcount = row
909                    break
910                self.rowcount += 1
911                yield row
912
913        if stream is None:
914            # no input stream, return the generator
915            return copy()
916
917        # write the rows to the file-like input stream
918        for row in copy():
919            write(row)
920
921        # return the cursor object, so you can chain operations
922        return self
923
924    def __next__(self):
925        """Return the next row (support for the iteration protocol)."""
926        res = self.fetchone()
927        if res is None:
928            raise StopIteration
929        return res
930
931    # Note that since Python 2.6 the iterator protocol uses __next()__
932    # instead of next(), we keep it only for backward compatibility of pgdb.
933    next = __next__
934
935    @staticmethod
936    def nextset():
937        """Not supported."""
938        raise NotSupportedError("The nextset() method is not supported")
939
940    @staticmethod
941    def setinputsizes(sizes):
942        """Not supported."""
943        pass  # unsupported, but silently passed
944
945    @staticmethod
946    def setoutputsize(size, column=0):
947        """Not supported."""
948        pass  # unsupported, but silently passed
949
950    @staticmethod
951    def row_factory(row):
952        """Process rows before they are returned.
953
954        You can overwrite this statically with a custom row factory, or
955        you can build a row factory dynamically with build_row_factory().
956
957        For example, you can create a Cursor class that returns rows as
958        Python dictionaries like this:
959
960            class DictCursor(pgdb.Cursor):
961
962                def row_factory(self, row):
963                    return {desc[0]: value
964                        for desc, value in zip(self.description, row)}
965
966            cur = DictCursor(con)  # get one DictCursor instance or
967            con.cursor_type = DictCursor  # always use DictCursor instances
968        """
969        raise NotImplementedError
970
971    def build_row_factory(self):
972        """Build a row factory based on the current description.
973
974        This implementation builds a row factory for creating named tuples.
975        You can overwrite this method if you want to dynamically create
976        different row factories whenever the column description changes.
977        """
978        colnames = self.colnames
979        if colnames:
980            try:
981                try:
982                    return namedtuple('Row', colnames, rename=True)._make
983                except TypeError:  # Python 2.6 and 3.0 do not support rename
984                    colnames = [v if v.isalnum() else 'column_%d' % n
985                             for n, v in enumerate(colnames)]
986                    return namedtuple('Row', colnames)._make
987            except ValueError:  # there is still a problem with the field names
988                colnames = ['column_%d' % n for n in range(len(colnames))]
989                return namedtuple('Row', colnames)._make
990
991
# The seven items describing a result column, as listed by DB-API 2
# for Cursor.description:
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
995
996
997### Connection Objects
998
class Connection(object):
    """Connection object.

    Wraps the given low-level connection and hands out cursors for it.
    Used as a context manager, it commits the pending transaction when the
    context is left normally and rolls it back when an exception occurred;
    the connection itself is not closed.
    """

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object wrapping cnx.

        Raises the error built by _op_error if cnx cannot create a source.
        """
        self._cnx = cnx  # the wrapped low-level connection
        self._tnx = False  # whether a transaction is pending
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor
        try:
            self._cnx.source()  # only used to validate the connection
        except Exception:
            raise _op_error("Invalid connection")

    def __enter__(self):
        """Enter the runtime context, e.g. for running a transaction."""
        return self

    def __exit__(self, et, ev, tb):
        """Leave the runtime context, ending the pending transaction.

        Rolls back if an exception was raised inside the context,
        otherwise commits.  The connection stays open.
        """
        if et is not None or ev is not None or tb is not None:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection, rolling back any pending transaction."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # the connection is closed anyway
        self._cnx.close()
        self._cnx = None

    def _end_transaction(self, command, message):
        """Send command to end the pending transaction, if there is one.

        Database errors are passed through, any other failure is reported
        as an error built by _op_error with the given message.
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return
        self._tnx = False
        try:
            self._cnx.source().execute(command)
        except DatabaseError:
            raise
        except Exception:
            raise _op_error(message)

    def commit(self):
        """Commit any pending transaction to the database."""
        self._end_transaction("COMMIT", "Can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        self._end_transaction("ROLLBACK", "Can't rollback")

    def cursor(self):
        """Return a new cursor of type cursor_type for this connection."""
        if not self._cnx:
            raise _op_error("Connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("Invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Run an operation on a new implicit cursor and return it."""
            cur = self.cursor()
            cur.execute(operation, params)
            return cur

        def executemany(self, operation, param_seq):
            """Run an operation for each parameter set on a new cursor."""
            cur = self.cursor()
            cur.executemany(operation, param_seq)
            return cur
1106
1107
1108### Module Interface
1109
# Keep a reference to the low-level connect() imported earlier in this
# module, since the name is rebound by the DB-API connect() below.
_connect_ = connect

def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database and return a Connection object.

    The dsn has the form 'host:database:user:password:opt'; all parts are
    optional and missing trailing parts are taken as empty.  The keyword
    arguments user, password and database override the respective dsn
    parts.  The host argument overrides the dsn host and may carry a port
    in the form 'host:port'.  An empty host or user is passed on as None.
    """
    try:
        fields = dsn.split(":")
    except (AttributeError, TypeError):  # no usable dsn string
        fields = []
    # pad missing parts with empty strings and ignore any extra parts
    dbhost, dbbase, dbuser, dbpasswd, dbopt = (list(fields) + [""] * 5)[:5]
    dbport = -1

    # explicit keyword arguments take precedence over the dsn
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            host_fields = host.split(":")
            dbhost = host_fields[0]
            dbport = int(host_fields[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # keep what has been parsed so far

    # empty strings mean "use the default" (e.g. an empty host is localhost)
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    return Connection(_connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd))
1157
1158
1159### Types Handling
1160
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozenset of type names.  It can be created from an
    iterable of names or from a single string of whitespace-separated
    names.  A Type compares equal to a string that is one of its names,
    where a leading underscore (marking an array type name) is ignored.
    Comparisons with other objects use the frozenset semantics.
    """

    def __new__(cls, values):
        # a single string is taken as a list of whitespace-separated names
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Python 3 sets __hash__ to None in classes that override __eq__ without
    # defining __hash__, which would make Type objects unhashable there,
    # unlike frozensets and unlike on Python 2.  Keep the frozenset hash.
    # Note that it is only consistent with the equality among Type and
    # frozenset objects, not with the equality to type name strings.
    __hash__ = frozenset.__hash__
1188
1189
class ArrayType:
    """Type class for PostgreSQL array types.

    Compares equal to any type name starting with an underscore (the
    naming used for array types) and to other ArrayType instances.
    """

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return isinstance(other, ArrayType)
        return other.startswith('_')

    def __ne__(self, other):
        if not isinstance(other, basestring):
            return not isinstance(other, ArrayType)
        return not other.startswith('_')
1204
1205
class RecordType:
    """Type class for PostgreSQL record types.

    Compares equal to TypeCode objects whose type attribute is 'c'
    (composite types), to the type name 'record', and to other
    RecordType instances.  TypeCode objects are checked first.
    """

    def __eq__(self, other):
        if isinstance(other, TypeCode):
            return other.type == 'c'
        if isinstance(other, basestring):
            return other == 'record'
        return isinstance(other, RecordType)

    def __ne__(self, other):
        if isinstance(other, TypeCode):
            return other.type != 'c'
        if isinstance(other, basestring):
            return other != 'record'
        return not isinstance(other, RecordType)
1224
1225
# Mandatory type objects defined by DB-API 2 specs:

STRING = Type(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = Type(['bytea'])
NUMBER = Type(['int2', 'int4', 'serial', 'int8',
               'float4', 'float8', 'numeric', 'money'])
DATETIME = Type(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
                 'interval',
                 'abstime', 'reltime'])  # these are very old
ROWID = Type(['oid'])


# Additional type objects (more specific):

BOOL = Type(['bool'])
SMALLINT = Type(['int2'])
INTEGER = Type(['int2', 'int4', 'int8', 'serial'])
LONG = Type(['int8'])
FLOAT = Type(['float4', 'float8'])
NUMERIC = Type(['numeric'])
MONEY = Type(['money'])
DATE = Type(['date'])
TIME = Type(['time', 'timetz'])
TIMESTAMP = Type(['timestamp', 'timestamptz'])
INTERVAL = Type(['interval'])
JSON = Type(['json', 'jsonb'])

# Type object for arrays (array type names, which start with an underscore,
# also compare equal to the type objects of their element types above):

ARRAY = ArrayType()

# Type object for records (encompassing all composite types):

RECORD = RecordType()
1258
1259
1260# Mandatory type helpers defined by DB-API 2 specs:
1261
def Date(year, month, day):
    """Construct an object holding a date value."""
    return date(year, month, day)


def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Construct an object holding a time value.

    Passing a tzinfo object creates a timezone aware time value.
    """
    return time(hour, minute, second, microsecond, tzinfo)


def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Construct an object holding a time stamp value.

    Passing a tzinfo object creates a timezone aware time stamp value.
    """
    return datetime(year, month, day, hour, minute, second, microsecond,
        tzinfo)


def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks (seconds since the epoch) are converted to local time.
    """
    return Date(*localtime(ticks)[:3])


def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    The ticks (seconds since the epoch) are converted to local time.
    """
    return Time(*localtime(ticks)[3:6])


def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    The ticks (seconds since the epoch) are converted to local time.
    """
    return Timestamp(*localtime(ticks)[:6])
1290
1291
class Binary(bytes):
    """Hold a binary (long) string value, such as the value of a bytea."""

    pass
1294
1295
1296# Additional type helpers for PyGreSQL:
1297
class Json:
    """Wrap an object so that it is passed to the database as JSON.

    Strings are used unchanged (taken as already serialized), any other
    object is serialized with the given encode function, which defaults
    to jsonencode.
    """

    def __init__(self, obj, encode=None):
        self.obj = obj
        self.encode = encode or jsonencode

    def __str__(self):
        value = self.obj
        return value if isinstance(value, basestring) else self.encode(value)

    __pg_repr__ = __str__
1312
1313
# If run as script, print some information:
# the PyGreSQL version (the module-level name `version`, defined elsewhere
# in this file), an empty line, and the module docstring.

if __name__ == '__main__':
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.