source: trunk/pgdb.py @ 797

Last change on this file since 797 was 797, checked in by cito, 3 years ago

Cache typecast functions and make them configurable

The typecast functions used by the pgdb module are now cached
using a local and a global Typecasts class. The local cache is
bound to the connection and knows how to cast composite types.

Also added functions that allow registering custom typecast
functions at the global and local level.

Also added a chapter on type adaptation and casting to the docs.

  • Property svn:keywords set to Id
File size: 43.1 KB
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 797 2016-01-29 22:43:22Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
24    # All parts are optional. You may also pass the host, database,
25    # user and password as keyword arguments. To pass a port,
26    # include it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary) if they are
32    # passed. The binding syntax is the same as the % operator
33    # for dictionaries, and no quoting is done.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from json import loads as jsondecode, dumps as jsonencode
76
77try:
78    long
79except NameError:  # Python >= 3.0
80    long = int
81
82try:
83    unicode
84except NameError:  # Python >= 3.0
85    unicode = str
86
87try:
88    basestring
89except NameError:  # Python >= 3.0
90    basestring = (str, bytes)
91
92from collections import Iterable
93try:
94    from collections import OrderedDict
95except ImportError:  # Python 2.6 or 3.0
96    try:
97        from ordereddict import OrderedDict
98    except Exception:
99        def OrderedDict(*args):
100            raise NotSupportedError('OrderedDict is not supported')
101
102
103### Module Constants
104
105# compliant with DB API 2.0
106apilevel = '2.0'
107
108# module may be shared, but not connections
109threadsafety = 1
110
111# this module uses extended Python format codes
112paramstyle = 'pyformat'
113
114# shortcut methods have been excluded from DB API 2 and
115# are not recommended by the DB SIG, but they can be handy
116shortcutmethods = 1
117
118
119### Internal Type Handling
120
121def decimal_type(decimal_type=None):
122    """Get or set global type to be used for decimal values.
123
124    Note that connections cache cast functions. To be sure a global change
125    is picked up by a running connection, call con.type_cache.reset_typecast().
126    """
127    global Decimal
128    if decimal_type is not None:
129        Decimal = decimal_type
130        set_typecast('numeric', decimal_type)
131    return Decimal
132
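# A minimal sketch of how decimal_type() can be used; the helper function
# below is illustrative only and not part of the module's public interface.

def _example_decimal_type():
    """Illustrative only: have 'numeric' values cast with float()."""
    from decimal import Decimal as default_type
    decimal_type(float)  # use float for decimal values from now on
    assert decimal_type() is float  # without argument, returns the current type
    decimal_type(default_type)  # switch back to the default Decimal type
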
133
134def cast_bool(value):
135    """Cast boolean value in database format to bool."""
136    if value:
137        return value[0] in ('t', 'T')
138
139
140def cast_money(value):
141    """Cast money value in database format to Decimal."""
142    if value:
143        value = value.replace('(', '-')
144        return Decimal(''.join(c for c in value if c.isdigit() or c in '.-'))
145
146
147class Typecasts(dict):
148    """Dictionary mapping database types to typecast functions.
149
150    The cast functions get passed the string representation of a value
151    in the database and must convert it to the corresponding Python
152    object.  They can assume that they are never passed None, since
153    NULL values are already handled separately before the cast
154    functions are called.
155    """
156
157    # the default cast functions
158    # (str functions are ignored but have been added for faster access)
159    defaults = {'char': str, 'bpchar': str, 'name': str,
160        'text': str, 'varchar': str,
161        'bool': cast_bool, 'bytea': unescape_bytea,
162        'int2': int, 'int4': int, 'serial': int,
163        'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
164        'oid': long, 'oid8': long,
165        'float4': float, 'float8': float,
166        'numeric': Decimal, 'money': cast_money,
167        'anyarray': cast_array, 'record': cast_record}
168
169    def __missing__(self, typ):
170        """Create a cast function if it is not cached.
171
172        Note that this class never raises a KeyError,
173        but returns None when no special cast function exists.
174        """
175        cast = self.defaults.get(typ)
176        if cast:
177            # store default for faster access
178            self[typ] = cast
179        elif typ.startswith('_'):
180            # create array cast
181            base_cast = self[typ[1:]]
182            cast = self.create_array_cast(base_cast)
183            if base_cast:
184                # store only if base type exists
185                self[typ] = cast
186        return cast
187
188    def get(self, typ, default=None):
189        """Get the typecast function for the given database type."""
190        return self[typ] or default
191
192    def set(self, typ, cast):
193        """Set a typecast function for the specified database type(s)."""
194        if isinstance(typ, basestring):
195            typ = [typ]
196        if cast is None:
197            for t in typ:
198                self.pop(t, None)
199                self.pop('_%s' % t, None)
200        else:
201            if not callable(cast):
202                raise TypeError("Cast parameter must be callable")
203            for t in typ:
204                self[t] = cast
205                self.pop('_%s' % t, None)
206
207    def reset(self, typ=None):
208        """Reset the typecasts for the specified type(s) to their defaults.
209
210        When no type is specified, all typecasts will be reset.
211        """
212        defaults = self.defaults
213        if typ is None:
214            self.clear()
215            self.update(defaults)
216        else:
217            if isinstance(typ, basestring):
218                typ = [typ]
219            for t in typ:
220                self.set(t, defaults.get(t))
221
222    def create_array_cast(self, cast):
223        """Create an array typecast for the given base cast."""
224        return lambda v: cast_array(v, cast)
225
226    def create_record_cast(self, name, fields, casts):
227        """Create a named record typecast for the given fields and casts."""
228        record = namedtuple(name, fields)
229        return lambda v: record(*cast_record(v, casts))
230
231
232_typecasts = Typecasts()  # this is the global typecast dictionary
233
234
235def get_typecast(typ):
236    """Get the global typecast function for the given database type(s)."""
237    return _typecasts.get(typ)
238
239
240def set_typecast(typ, cast):
241    """Set a global typecast function for the given database type(s).
242
243    Note that connections cache cast functions. To be sure a global change
244    is picked up by a running connection, call con.type_cache.reset_typecast().
245    """
246    _typecasts.set(typ, cast)
247
248
249def reset_typecast(typ=None):
250    """Reset the global typecasts for the given type(s) to their default.
251
252    When no type is specified, all typecasts will be reset.
253
254    Note that connections cache cast functions. To be sure a global change
255    is picked up by a running connection, call con.type_cache.reset_typecast().
256    """
257    _typecasts.reset(typ)
258
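# A minimal sketch of how a custom typecast can be registered globally; the
# cast_point() helper and the 'point' type used below are illustrative
# assumptions, not typecasts that this module registers by default.

def _example_register_typecast():
    """Register an illustrative cast for the PostgreSQL 'point' type."""

    def cast_point(value):
        # the database sends points as strings like '(2.5,7.5)'
        x, y = value.strip('()').split(',')
        return float(x), float(y)

    set_typecast('point', cast_point)  # used for values fetched from now on
    assert get_typecast('point') is cast_point
    reset_typecast('point')  # remove it again, falling back to the default
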
259
260class LocalTypecasts(Typecasts):
261    """Map typecasts, including local composite types, to cast functions."""
262
263    defaults = _typecasts
264
265    def __missing__(self, typ):
266        """Create a cast function if it is not cached."""
267        if typ.startswith('_'):
268            base_cast = self[typ[1:]]
269            cast = self.create_array_cast(base_cast)
270            if base_cast:
271                self[typ] = cast
272        else:
273            cast = self.defaults.get(typ)
274            if cast:
275                self[typ] = cast
276            else:
277                fields = self.get_fields(typ)
278                if fields:
279                    casts = [self[field.type] for field in fields]
280                    fields = [field.name for field in fields]
281                    cast = self.create_record_cast(typ, fields, casts)
282                    self[typ] = cast
283        return cast
284
285    def get_fields(self, typ):
286        """Return the fields for the given record type.
287
288        This method will be replaced with a method that looks up the fields
289        using the type cache of the connection.
290        """
291        return []
292
293
294class TypeCode(str):
295    """Class representing the type_code used by the DB-API 2.0.
296
297    TypeCode objects are strings equal to the PostgreSQL type name,
298    but carry some additional information.
299    """
300
301    @classmethod
302    def create(cls, oid, name, len, type, category, delim, relid):
303        """Create a type code for a PostgreSQL data type."""
304        self = cls(name)
305        self.oid = oid
306        self.len = len
307        self.type = type
308        self.category = category
309        self.delim = delim
310        self.relid = relid
311        return self
312
313FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
314
315
316class TypeCache(dict):
317    """Cache for database types.
318
319    This cache maps type OIDs and names to TypeCode strings containing
320    important information on the associated database type.
321    """
322
323    def __init__(self, cnx):
324        """Initialize type cache for connection."""
325        super(TypeCache, self).__init__()
326        self._escape_string = cnx.escape_string
327        self._src = cnx.source()
328        self._typecasts = LocalTypecasts()
329        self._typecasts.get_fields = self.get_fields
330
331    def __missing__(self, key):
332        """Get the type info from the database if it is not cached."""
333        if isinstance(key, int):
334            oid = key
335        else:
336            if '.' not in key and '"' not in key:
337                key = '"%s"' % key
338            oid = "'%s'::regtype" % self._escape_string(key)
339        try:
340            self._src.execute("SELECT oid, typname,"
341                 " typlen, typtype, typcategory, typdelim, typrelid"
342                " FROM pg_type WHERE oid=%s" % oid)
343        except ProgrammingError:
344            res = None
345        else:
346            res = self._src.fetch(1)
347        if not res:
348            raise KeyError('Type %s could not be found' % key)
349        res = res[0]
350        type_code = TypeCode.create(int(res[0]), res[1],
351            int(res[2]), res[3], res[4], res[5], int(res[6]))
352        self[type_code.oid] = self[str(type_code)] = type_code
353        return type_code
354
355    def get(self, key, default=None):
356        """Get the type even if it is not cached."""
357        try:
358            return self[key]
359        except KeyError:
360            return default
361
362    def get_fields(self, typ):
363        """Get the names and types of the fields of composite types."""
364        if not isinstance(typ, TypeCode):
365            typ = self.get(typ)
366            if not typ:
367                return None
368        if not typ.relid:
369            return None  # this type is not composite
370        self._src.execute("SELECT attname, atttypid"
371            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
372            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
373        return [FieldInfo(name, self.get(int(oid)))
374            for name, oid in self._src.fetch(-1)]
375
376    def get_typecast(self, typ):
377        """Get the typecast function for the given database type."""
378        return self._typecasts.get(typ)
379
380    def set_typecast(self, typ, cast):
381        """Set a typecast function for the specified database type(s)."""
382        self._typecasts.set(typ, cast)
383
384    def reset_typecast(self, typ=None):
385        """Reset the typecast function for the specified database type(s)."""
386        self._typecasts.reset(typ)
387
388    def typecast(self, typ, value):
389        """Cast the given value according to the given database type."""
390        if value is None:
391            # for NULL values, no typecast is necessary
392            return None
393        cast = self.get_typecast(typ)
394        if not cast or cast is str:
395            # no typecast is necessary
396            return value
397        return cast(value)
398
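# A minimal sketch of how the per-connection cache can be used; the `con`
# parameter stands for an open pgdb connection and the 'citext' type below
# is only an illustrative example.

def _example_connection_typecasts(con):
    """Inspect and override typecasts for a single connection."""
    cache = con.type_cache  # the TypeCache instance bound to the connection
    assert cache.get_typecast('int4') is int  # the default cast for int4
    cache.set_typecast('citext', str)  # affects only this connection
    cache.reset_typecast('citext')  # fall back to the global default again
    return cache.get_typecast('numeric')  # Decimal, unless changed globally
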
399
400class _quotedict(dict):
401    """Dictionary with auto quoting of its items.
402
403    The quote attribute must be set to the desired quote function.
404    """
405
406    def __getitem__(self, key):
407        return self.quote(super(_quotedict, self).__getitem__(key))
408
409
410### Error messages
411
412def _db_error(msg, cls=DatabaseError):
413    """Return DatabaseError with empty sqlstate attribute."""
414    error = cls(msg)
415    error.sqlstate = None
416    return error
417
418
419def _op_error(msg):
420    """Return OperationalError."""
421    return _db_error(msg, OperationalError)
422
423
424### Cursor Object
425
426class Cursor(object):
427    """Cursor object."""
428
429    def __init__(self, dbcnx):
430        """Create a cursor object for the database connection."""
431        self.connection = self._dbcnx = dbcnx
432        self._cnx = dbcnx._cnx
433        self.type_cache = dbcnx.type_cache
434        self._src = self._cnx.source()
435        # the official attribute for describing the result columns
436        self._description = None
437        if self.row_factory is Cursor.row_factory:
438            # the row factory needs to be determined dynamically
439            self.row_factory = None
440        else:
441            self.build_row_factory = None
442        self.rowcount = -1
443        self.arraysize = 1
444        self.lastrowid = None
445
446    def __iter__(self):
447        """Make cursor compatible to the iteration protocol."""
448        return self
449
450    def __enter__(self):
451        """Enter the runtime context for the cursor object."""
452        return self
453
454    def __exit__(self, et, ev, tb):
455        """Exit the runtime context for the cursor object."""
456        self.close()
457
458    def _quote(self, value):
459        """Quote value depending on its type."""
460        if value is None:
461            return 'NULL'
462        if isinstance(value, (datetime, date, time, timedelta, Json)):
463            value = str(value)
464        if isinstance(value, basestring):
465            if isinstance(value, Binary):
466                value = self._cnx.escape_bytea(value)
467                if bytes is not str:  # Python >= 3.0
468                    value = value.decode('ascii')
469            else:
470                value = self._cnx.escape_string(value)
471            return "'%s'" % value
472        if isinstance(value, float):
473            if isinf(value):
474                return "'-Infinity'" if value < 0 else "'Infinity'"
475            if isnan(value):
476                return "'NaN'"
477            return value
478        if isinstance(value, (int, long, Decimal)):
479            return value
480        if isinstance(value, list):
481            # Quote value as an ARRAY constructor. This is better than using
482            # an array literal because it carries the information that this is
483            # an array and not a string.  One issue with this syntax is that
484            # you need to add an explicit typecast when passing empty arrays.
485            # The ARRAY keyword is actually only necessary at the top level.
486            q = self._quote
487            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
488        if isinstance(value, tuple):
489            # Quote as a ROW constructor.  This is better than using a record
490            # literal because it carries the information that this is a record
491            # and not a string.  We don't use the keyword ROW in order to make
492            # this usable with the IN syntax as well.  It is only necessary
493            # when the record has a single column, which is not very useful.
494            q = self._quote
495            return '(%s)' % ','.join(str(q(v)) for v in value)
496        try:
497            value = value.__pg_repr__()
498            if isinstance(value, (tuple, list)):
499                value = self._quote(value)
500            return value
501        except AttributeError:
502            raise InterfaceError(
503                'Do not know how to adapt type %s' % type(value))
504
505    def _quoteparams(self, string, parameters):
506        """Quote parameters.
507
508        This function works for both mappings and sequences.
509        """
510        if isinstance(parameters, dict):
511            parameters = _quotedict(parameters)
512            parameters.quote = self._quote
513        else:
514            parameters = tuple(map(self._quote, parameters))
515        return string % parameters
516
517    def _make_description(self, info):
518        """Make the description tuple for the given field info."""
519        name, typ, size, mod = info[1:]
520        type_code = self.type_cache[typ]
521        if mod > 0:
522            mod -= 4
523        if type_code == 'numeric':
524            precision, scale = mod >> 16, mod & 0xffff
525            size = precision
526        else:
527            if not size:
528                size = type_code.len  # the type length from pg_type
529            if size == -1:
530                size = mod
531            precision = scale = None
532        return CursorDescription(name, type_code,
533            None, size, precision, scale, None)
534
535    @property
536    def description(self):
537        """Read-only attribute describing the result columns."""
538        descr = self._description
539        if self._description is True:
540            make = self._make_description
541            descr = [make(info) for info in self._src.listinfo()]
542            self._description = descr
543        return descr
544
545    @property
546    def colnames(self):
547        """Unofficial convenience method for getting the column names."""
548        return [d[0] for d in self.description]
549
550    @property
551    def coltypes(self):
552        """Unofficial convenience method for getting the column types."""
553        return [d[1] for d in self.description]
554
555    def close(self):
556        """Close the cursor object."""
557        self._src.close()
558        self._description = None
559        self.rowcount = -1
560        self.lastrowid = None
561
562    def execute(self, operation, parameters=None):
563        """Prepare and execute a database operation (query or command)."""
564        # The parameters may also be specified as a list of tuples, e.g. to
565        # insert multiple rows in a single operation, but this kind of
566        # usage is deprecated.  We make several plausibility checks because
567        # tuples can also be passed with the meaning of ROW constructors.
568        if (parameters and isinstance(parameters, list)
569                and len(parameters) > 1
570                and all(isinstance(p, tuple) for p in parameters)
571                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
572            return self.executemany(operation, parameters)
573        else:
574            # not a list of tuples
575            return self.executemany(operation, [parameters])
576
577    def executemany(self, operation, seq_of_parameters):
578        """Prepare operation and execute it against a parameter sequence."""
579        if not seq_of_parameters:
580            # don't do anything without parameters
581            return
582        self._description = None
583        self.rowcount = -1
584        # first try to execute all queries
585        rowcount = 0
586        sql = "BEGIN"
587        try:
588            if not self._dbcnx._tnx:
589                try:
590                    self._cnx.source().execute(sql)
591                except DatabaseError:
592                    raise  # database provides error message
593                except Exception as err:
594                    raise _op_error("Can't start transaction")
595                self._dbcnx._tnx = True
596            for parameters in seq_of_parameters:
597                sql = operation
598                if parameters:
599                    sql = self._quoteparams(sql, parameters)
600                rows = self._src.execute(sql)
601                if rows:  # true if not DML
602                    rowcount += rows
603                else:
604                    self.rowcount = -1
605        except DatabaseError:
606            raise  # database provides error message
607        except Error as err:
608            raise _db_error(
609                "Error in '%s': '%s' " % (sql, err), InterfaceError)
610        except Exception as err:
611            raise _op_error("Internal error in '%s': %s" % (sql, err))
612        # then initialize result raw count and description
613        if self._src.resulttype == RESULT_DQL:
614            self._description = True  # fetch on demand
615            self.rowcount = self._src.ntuples
616            self.lastrowid = None
617            if self.build_row_factory:
618                self.row_factory = self.build_row_factory()
619        else:
620            self.rowcount = rowcount
621            self.lastrowid = self._src.oidstatus()
622        # return the cursor object, so you can write statements such as
623        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
624        return self
625
626    def fetchone(self):
627        """Fetch the next row of a query result set."""
628        res = self.fetchmany(1, False)
629        try:
630            return res[0]
631        except IndexError:
632            return None
633
634    def fetchall(self):
635        """Fetch all (remaining) rows of a query result."""
636        return self.fetchmany(-1, False)
637
638    def fetchmany(self, size=None, keep=False):
639        """Fetch the next set of rows of a query result.
640
641        The number of rows to fetch per call is specified by the
642        size parameter. If it is not given, the cursor's arraysize
643        determines the number of rows to be fetched. If you set
644        the keep parameter to true, this is kept as new arraysize.
645        """
646        if size is None:
647            size = self.arraysize
648        if keep:
649            self.arraysize = size
650        try:
651            result = self._src.fetch(size)
652        except DatabaseError:
653            raise
654        except Error as err:
655            raise _db_error(str(err))
656        typecast = self.type_cache.typecast
657        return [self.row_factory([typecast(typ, value)
658            for typ, value in zip(self.coltypes, row)]) for row in result]
659
660    def callproc(self, procname, parameters=None):
661        """Call a stored database procedure with the given name.
662
663        The sequence of parameters must contain one entry for each input
664        argument that the procedure expects. The result of the call is the
665        same as this input sequence; replacement of output and input/output
666        parameters in the return value is currently not supported.
667
668        The procedure may also provide a result set as output, which can
669        then be requested through the standard fetch methods of the cursor.
670        """
671        n = len(parameters) if parameters else 0
672        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
673        self.execute(query, parameters)
674        return parameters
675
676    def copy_from(self, stream, table,
677            format=None, sep=None, null=None, size=None, columns=None):
678        """Copy data from an input stream to the specified table.
679
680        The input stream can be a file-like object with a read() method or
681        it can also be an iterable returning a row or multiple rows of input
682        on each iteration.
683
684        The format must be text, csv or binary. The sep option sets the
685        column separator (delimiter) used in the non-binary formats.
686        The null option sets the textual representation of NULL in the input.
687
688        The size option sets the size of the buffer used when reading data
689        from file-like objects.
690
691        The copy operation can be restricted to a subset of columns. If no
692        columns are specified, all of them will be copied.
693        """
694        binary_format = format == 'binary'
695        try:
696            read = stream.read
697        except AttributeError:
698            if size:
699                raise ValueError("Size must only be set for file-like objects")
700            if binary_format:
701                input_type = bytes
702                type_name = 'byte strings'
703            else:
704                input_type = basestring
705                type_name = 'strings'
706
707            if isinstance(stream, basestring):
708                if not isinstance(stream, input_type):
709                    raise ValueError("The input must be %s" % type_name)
710                if not binary_format:
711                    if isinstance(stream, str):
712                        if not stream.endswith('\n'):
713                            stream += '\n'
714                    else:
715                        if not stream.endswith(b'\n'):
716                            stream += b'\n'
717
718                def chunks():
719                    yield stream
720
721            elif isinstance(stream, Iterable):
722
723                def chunks():
724                    for chunk in stream:
725                        if not isinstance(chunk, input_type):
726                            raise ValueError(
727                                "Input stream must consist of %s" % type_name)
728                        if isinstance(chunk, str):
729                            if not chunk.endswith('\n'):
730                                chunk += '\n'
731                        else:
732                            if not chunk.endswith(b'\n'):
733                                chunk += b'\n'
734                        yield chunk
735
736            else:
737                raise TypeError("Need an input stream to copy from")
738        else:
739            if size is None:
740                size = 8192
741            elif not isinstance(size, int):
742                raise TypeError("The size option must be an integer")
743            if size > 0:
744
745                def chunks():
746                    while True:
747                        buffer = read(size)
748                        yield buffer
749                        if not buffer or len(buffer) < size:
750                            break
751
752            else:
753
754                def chunks():
755                    yield read()
756
757        if not table or not isinstance(table, basestring):
758            raise TypeError("Need a table to copy to")
759        if table.lower().startswith('select'):
760            raise ValueError("Must specify a table, not a query")
761        else:
762            table = '"%s"' % (table,)
763        operation = ['copy %s' % (table,)]
764        options = []
765        params = []
766        if format is not None:
767            if not isinstance(format, basestring):
768                raise TypeError("The frmat option must be be a string")
769            if format not in ('text', 'csv', 'binary'):
770                raise ValueError("Invalid format")
771            options.append('format %s' % (format,))
772        if sep is not None:
773            if not isinstance(sep, basestring):
774                raise TypeError("The sep option must be a string")
775            if format == 'binary':
776                raise ValueError(
777                    "The sep option is not allowed with binary format")
778            if len(sep) != 1:
779                raise ValueError(
780                    "The sep option must be a single one-byte character")
781            options.append('delimiter %s')
782            params.append(sep)
783        if null is not None:
784            if not isinstance(null, basestring):
785                raise TypeError("The null option must be a string")
786            options.append('null %s')
787            params.append(null)
788        if columns:
789            if not isinstance(columns, basestring):
790                columns = ','.join('"%s"' % (col,) for col in columns)
791            operation.append('(%s)' % (columns,))
792        operation.append("from stdin")
793        if options:
794            operation.append('(%s)' % ','.join(options))
795        operation = ' '.join(operation)
796
797        putdata = self._src.putdata
798        self.execute(operation, params)
799
800        try:
801            for chunk in chunks():
802                putdata(chunk)
803        except BaseException as error:
804            self.rowcount = -1
805            # the following call will re-raise the error
806            putdata(error)
807        else:
808            self.rowcount = putdata(None)
809
810        # return the cursor object, so you can chain operations
811        return self
812
813    def copy_to(self, stream, table,
814            format=None, sep=None, null=None, decode=None, columns=None):
815        """Copy data from the specified table to an output stream.
816
817        The output stream can be a file-like object with a write() method or
818        it can also be None, in which case the method will return a generator
819        yielding a row on each iteration.
820
821        Output will be returned as byte strings unless you set decode to true.
822
823        Note that you can also use a select query instead of the table name.
824
825        The format must be text, csv or binary. The sep option sets the
826        column separator (delimiter) used in the non-binary formats.
827        The null option sets the textual representation of NULL in the output.
828
829        The copy operation can be restricted to a subset of columns. If no
830        columns are specified, all of them will be copied.
831        """
832        binary_format = format == 'binary'
833        if stream is not None:
834            try:
835                write = stream.write
836            except AttributeError:
837                raise TypeError("Need an output stream to copy to")
838        if not table or not isinstance(table, basestring):
839            raise TypeError("Need a table to copy to")
840        if table.lower().startswith('select'):
841            if columns:
842                raise ValueError("Columns must be specified in the query")
843            table = '(%s)' % (table,)
844        else:
845            table = '"%s"' % (table,)
846        operation = ['copy %s' % (table,)]
847        options = []
848        params = []
849        if format is not None:
850            if not isinstance(format, basestring):
851                raise TypeError("The format option must be a string")
852            if format not in ('text', 'csv', 'binary'):
853                raise ValueError("Invalid format")
854            options.append('format %s' % (format,))
855        if sep is not None:
856            if not isinstance(sep, basestring):
857                raise TypeError("The sep option must be a string")
858            if binary_format:
859                raise ValueError(
860                    "The sep option is not allowed with binary format")
861            if len(sep) != 1:
862                raise ValueError(
863                    "The sep option must be a single one-byte character")
864            options.append('delimiter %s')
865            params.append(sep)
866        if null is not None:
867            if not isinstance(null, basestring):
868                raise TypeError("The null option must be a string")
869            options.append('null %s')
870            params.append(null)
871        if decode is None:
872            if format == 'binary':
873                decode = False
874            else:
875                decode = str is unicode
876        else:
877            if not isinstance(decode, (int, bool)):
878                raise TypeError("The decode option must be a boolean")
879            if decode and binary_format:
880                raise ValueError(
881                    "The decode option is not allowed with binary format")
882        if columns:
883            if not isinstance(columns, basestring):
884                columns = ','.join('"%s"' % (col,) for col in columns)
885            operation.append('(%s)' % (columns,))
886
887        operation.append("to stdout")
888        if options:
889            operation.append('(%s)' % ','.join(options))
890        operation = ' '.join(operation)
891
892        getdata = self._src.getdata
893        self.execute(operation, params)
894
895        def copy():
896            self.rowcount = 0
897            while True:
898                row = getdata(decode)
899                if isinstance(row, int):
900                    if self.rowcount != row:
901                        self.rowcount = row
902                    break
903                self.rowcount += 1
904                yield row
905
906        if stream is None:
907            # no input stream, return the generator
908            return copy()
909
910        # write the rows to the file-like output stream
911        for row in copy():
912            write(row)
913
914        # return the cursor object, so you can chain operations
915        return self
916
917    def __next__(self):
918        """Return the next row (support for the iteration protocol)."""
919        res = self.fetchone()
920        if res is None:
921            raise StopIteration
922        return res
923
924    # Note that since Python 3.0 the iterator protocol uses __next__()
925    # instead of next(); we keep next() only for backward compatibility of pgdb.
926    next = __next__
927
928    @staticmethod
929    def nextset():
930        """Not supported."""
931        raise NotSupportedError("The nextset() method is not supported")
932
933    @staticmethod
934    def setinputsizes(sizes):
935        """Not supported."""
936        pass  # unsupported, but silently passed
937
938    @staticmethod
939    def setoutputsize(size, column=0):
940        """Not supported."""
941        pass  # unsupported, but silently passed
942
943    @staticmethod
944    def row_factory(row):
945        """Process rows before they are returned.
946
947        You can overwrite this statically with a custom row factory, or
948        you can build a row factory dynamically with build_row_factory().
949
950        For example, you can create a Cursor class that returns rows as
951        Python dictionaries like this:
952
953            class DictCursor(pgdb.Cursor):
954
955                def row_factory(self, row):
956                    return {desc[0]: value
957                        for desc, value in zip(self.description, row)}
958
959            cur = DictCursor(con)  # get one DictCursor instance or
960            con.cursor_type = DictCursor  # always use DictCursor instances
961        """
962        raise NotImplementedError
963
964    def build_row_factory(self):
965        """Build a row factory based on the current description.
966
967        This implementation builds a row factory for creating named tuples.
968        You can overwrite this method if you want to dynamically create
969        different row factories whenever the column description changes.
970        """
971        colnames = self.colnames
972        if colnames:
973            try:
974                try:
975                    return namedtuple('Row', colnames, rename=True)._make
976                except TypeError:  # Python 2.6 and 3.0 do not support rename
977                    colnames = [v if v.isalnum() else 'column_%d' % n
978                             for n, v in enumerate(colnames)]
979                    return namedtuple('Row', colnames)._make
980            except ValueError:  # there is still a problem with the field names
981                colnames = ['column_%d' % n for n in range(len(colnames))]
982                return namedtuple('Row', colnames)._make
983
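# A minimal sketch of how the copy methods above can be used; the cursor
# argument and the "copytest" table are illustrative assumptions only.

def _example_copy(cur):
    """Bulk load a few rows and read them back using COPY."""
    rows = ['42\tHello, world!\n', '43\tGoodbye!\n']  # tab-separated text rows
    cur.copy_from(rows, 'copytest', format='text', columns=('id', 'msg'))
    # with stream=None, copy_to() returns a generator yielding output rows
    return list(cur.copy_to(None, 'copytest', format='csv', decode=True))
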
984
985CursorDescription = namedtuple('CursorDescription',
986    ['name', 'type_code', 'display_size', 'internal_size',
987     'precision', 'scale', 'null_ok'])
988
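# A minimal sketch of how Python values are adapted when passed as query
# parameters (see Cursor._quote above); the cursor argument is assumed to
# come from an open connection and the queries are illustrative only.

def _example_parameters(cur):
    """Pass scalars, lists and tuples as pyformat parameters."""
    # dictionaries use %(name)s placeholders, sequences use plain %s
    cur.execute("select %(n)s, %(t)s", {'n': 42, 't': "it's"})
    # a list is rendered as an ARRAY constructor, a tuple as a ROW constructor
    cur.execute("select %s::int[], %s", ([1, 2, 3], (42, 'x')))
    return cur.fetchone()
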
989
990### Connection Objects
991
992class Connection(object):
993    """Connection object."""
994
995    # expose the exceptions as attributes on the connection object
996    Error = Error
997    Warning = Warning
998    InterfaceError = InterfaceError
999    DatabaseError = DatabaseError
1000    InternalError = InternalError
1001    OperationalError = OperationalError
1002    ProgrammingError = ProgrammingError
1003    IntegrityError = IntegrityError
1004    DataError = DataError
1005    NotSupportedError = NotSupportedError
1006
1007    def __init__(self, cnx):
1008        """Create a database connection object."""
1009        self._cnx = cnx  # connection
1010        self._tnx = False  # transaction state
1011        self.type_cache = TypeCache(cnx)
1012        self.cursor_type = Cursor
1013        try:
1014            self._cnx.source()
1015        except Exception:
1016            raise _op_error("Invalid connection")
1017
1018    def __enter__(self):
1019        """Enter the runtime context for the connection object.
1020
1021        The runtime context can be used for running transactions.
1022        """
1023        return self
1024
1025    def __exit__(self, et, ev, tb):
1026        """Exit the runtime context for the connection object.
1027
1028        This does not close the connection, but it ends a transaction.
1029        """
1030        if et is None and ev is None and tb is None:
1031            self.commit()
1032        else:
1033            self.rollback()
1034
1035    def close(self):
1036        """Close the connection object."""
1037        if self._cnx:
1038            if self._tnx:
1039                try:
1040                    self.rollback()
1041                except DatabaseError:
1042                    pass
1043            self._cnx.close()
1044            self._cnx = None
1045        else:
1046            raise _op_error("Connection has been closed")
1047
1048    def commit(self):
1049        """Commit any pending transaction to the database."""
1050        if self._cnx:
1051            if self._tnx:
1052                self._tnx = False
1053                try:
1054                    self._cnx.source().execute("COMMIT")
1055                except DatabaseError:
1056                    raise
1057                except Exception:
1058                    raise _op_error("Can't commit")
1059        else:
1060            raise _op_error("Connection has been closed")
1061
1062    def rollback(self):
1063        """Roll back to the start of any pending transaction."""
1064        if self._cnx:
1065            if self._tnx:
1066                self._tnx = False
1067                try:
1068                    self._cnx.source().execute("ROLLBACK")
1069                except DatabaseError:
1070                    raise
1071                except Exception:
1072                    raise _op_error("Can't rollback")
1073        else:
1074            raise _op_error("Connection has been closed")
1075
1076    def cursor(self):
1077        """Return a new cursor object using the connection."""
1078        if self._cnx:
1079            try:
1080                return self.cursor_type(self)
1081            except Exception:
1082                raise _op_error("Invalid connection")
1083        else:
1084            raise _op_error("Connection has been closed")
1085
1086    if shortcutmethods:  # otherwise do not implement and document this
1087
1088        def execute(self, operation, params=None):
1089            """Shortcut method to run an operation on an implicit cursor."""
1090            cursor = self.cursor()
1091            cursor.execute(operation, params)
1092            return cursor
1093
1094        def executemany(self, operation, param_seq):
1095            """Shortcut method to run an operation against a sequence."""
1096            cursor = self.cursor()
1097            cursor.executemany(operation, param_seq)
1098            return cursor
1099
1100
1101### Module Interface
1102
1103_connect_ = connect
1104
1105def connect(dsn=None,
1106        user=None, password=None,
1107        host=None, database=None):
1108    """Connect to a database."""
1109    # first get params from DSN
1110    dbport = -1
1111    dbhost = ""
1112    dbbase = ""
1113    dbuser = ""
1114    dbpasswd = ""
1115    dbopt = ""
1116    try:
1117        params = dsn.split(":")
1118        dbhost = params[0]
1119        dbbase = params[1]
1120        dbuser = params[2]
1121        dbpasswd = params[3]
1122        dbopt = params[4]
1123    except (AttributeError, IndexError, TypeError):
1124        pass
1125
1126    # override if necessary
1127    if user is not None:
1128        dbuser = user
1129    if password is not None:
1130        dbpasswd = password
1131    if database is not None:
1132        dbbase = database
1133    if host is not None:
1134        try:
1135            params = host.split(":")
1136            dbhost = params[0]
1137            dbport = int(params[1])
1138        except (AttributeError, IndexError, TypeError, ValueError):
1139            pass
1140
1141    # empty host is localhost
1142    if dbhost == "":
1143        dbhost = None
1144    if dbuser == "":
1145        dbuser = None
1146
1147    # open the connection
1148    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
1149    return Connection(cnx)
1150
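# A minimal sketch of the basic usage described in the module docstring; the
# connection parameters, table and query below are illustrative assumptions.

def _example_basic_usage():
    """Connect, run a parameterized query and fetch the result rows."""
    con = connect(database='testdb', host='localhost:5432', user='tester')
    try:
        cur = con.cursor()
        cur.execute("select name, weight from fruits where weight < %(max)s",
                    {'max': 100})
        rows = cur.fetchall()  # a list of named tuples (see build_row_factory)
        con.commit()
    finally:
        con.close()
    return rows
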
1151
1152### Types Handling
1153
1154class Type(frozenset):
1155    """Type class for a couple of PostgreSQL data types.
1156
1157    PostgreSQL is object-oriented: types are dynamic.
1158    We must thus use type names as internal type codes.
1159    """
1160
1161    def __new__(cls, values):
1162        if isinstance(values, basestring):
1163            values = values.split()
1164        return super(Type, cls).__new__(cls, values)
1165
1166    def __eq__(self, other):
1167        if isinstance(other, basestring):
1168            if other.startswith('_'):
1169                other = other[1:]
1170            return other in self
1171        else:
1172            return super(Type, self).__eq__(other)
1173
1174    def __ne__(self, other):
1175        if isinstance(other, basestring):
1176            if other.startswith('_'):
1177                other = other[1:]
1178            return other not in self
1179        else:
1180            return super(Type, self).__ne__(other)
1181
1182
1183class ArrayType:
1184    """Type class for PostgreSQL array types."""
1185
1186    def __eq__(self, other):
1187        if isinstance(other, basestring):
1188            return other.startswith('_')
1189        else:
1190            return isinstance(other, ArrayType)
1191
1192    def __ne__(self, other):
1193        if isinstance(other, basestring):
1194            return not other.startswith('_')
1195        else:
1196            return not isinstance(other, ArrayType)
1197
1198
1199class RecordType:
1200    """Type class for PostgreSQL record types."""
1201
1202    def __eq__(self, other):
1203        if isinstance(other, TypeCode):
1204            return other.type == 'c'
1205        elif isinstance(other, basestring):
1206            return other == 'record'
1207        else:
1208            return isinstance(other, RecordType)
1209
1210    def __ne__(self, other):
1211        if isinstance(other, TypeCode):
1212            return other.type != 'c'
1213        elif isinstance(other, basestring):
1214            return other != 'record'
1215        else:
1216            return not isinstance(other, RecordType)
1217
1218
1219# Mandatory type objects defined by DB-API 2 specs:
1220
1221STRING = Type('char bpchar name text varchar')
1222BINARY = Type('bytea')
1223NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
1224DATETIME = Type('date time timetz timestamp timestamptz interval'
1225    ' abstime reltime')  # these are very old
1226ROWID = Type('oid')
1227
1228
1229# Additional type objects (more specific):
1230
1231BOOL = Type('bool')
1232SMALLINT = Type('int2')
1233INTEGER = Type('int2 int4 int8 serial')
1234LONG = Type('int8')
1235FLOAT = Type('float4 float8')
1236NUMERIC = Type('numeric')
1237MONEY = Type('money')
1238DATE = Type('date')
1239TIME = Type('time timetz')
1240TIMESTAMP = Type('timestamp timestamptz')
1241INTERVAL = Type('interval')
1242JSON = Type('json jsonb')
1243
1244# Type object for arrays (also equate to their base types):
1245
1246ARRAY = ArrayType()
1247
1248# Type object for records (encompassing all composite types):
1249
1250RECORD = RecordType()
1251
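# A minimal sketch of how these type objects can be compared against the
# type_code values in cursor.description; the cursor argument is assumed
# to hold the result of a previously executed query.

def _example_check_column_types(cur):
    """Classify result columns using the DB-API 2 type objects."""
    kinds = []
    for column in cur.description:
        type_code = column.type_code  # a TypeCode string such as 'varchar'
        if type_code == ARRAY:
            # arrays also compare equal to the type object of their base
            # type, so check against ARRAY first
            kinds.append('array')
        elif type_code == STRING:
            kinds.append('string')
        elif type_code == NUMBER:
            kinds.append('number')
        else:
            kinds.append('other')
    return kinds
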
1252
1253# Mandatory type helpers defined by DB-API 2 specs:
1254
1255def Date(year, month, day):
1256    """Construct an object holding a date value."""
1257    return date(year, month, day)
1258
1259
1260def Time(hour, minute=0, second=0, microsecond=0):
1261    """Construct an object holding a time value."""
1262    return time(hour, minute, second, microsecond)
1263
1264
1265def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
1266    """Construct an object holding a time stamp value."""
1267    return datetime(year, month, day, hour, minute, second, microsecond)
1268
1269
1270def DateFromTicks(ticks):
1271    """Construct an object holding a date value from the given ticks value."""
1272    return Date(*localtime(ticks)[:3])
1273
1274
1275def TimeFromTicks(ticks):
1276    """Construct an object holding a time value from the given ticks value."""
1277    return Time(*localtime(ticks)[3:6])
1278
1279
1280def TimestampFromTicks(ticks):
1281    """Construct an object holding a time stamp from the given ticks value."""
1282    return Timestamp(*localtime(ticks)[:6])
1283
1284
1285class Binary(bytes):
1286    """Construct an object capable of holding a binary (long) string value."""
1287
1288
1289# Additional type helpers for PyGreSQL:
1290
1291class Json:
1292    """Construct a wrapper for holding an object serializable to JSON."""
1293
1294    def __init__(self, obj, encode=None):
1295        self.obj = obj
1296        self.encode = encode or jsonencode
1297
1298    def __str__(self):
1299        obj = self.obj
1300        if isinstance(obj, basestring):
1301            return obj
1302        return self.encode(obj)
1303
1304    __pg_repr__ = __str__
1305
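# A minimal sketch of how the Json wrapper can be passed as a parameter;
# the cursor argument and the "websites" table with its "config" column
# are illustrative assumptions only.

def _example_json_parameter(cur):
    """Store a Python dict as JSON using the Json wrapper."""
    settings = {'south': [0, 23], 'north': [40, 79]}
    cur.execute("insert into websites (config) values (%s)", (Json(settings),))
    # on the way back, 'json' and 'jsonb' values are decoded automatically
    cur.execute("select config from websites")
    return cur.fetchone()[0]  # a Python dict again
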
1306
1307# If run as script, print some information:
1308
1309if __name__ == '__main__':
1310    print('PyGreSQL version', version)
1311    print('')
1312    print(__doc__)