source: trunk/pgdb.py @ 801

Last change on this file since 801 was 799, checked in by cito, 4 years ago

Improve adaptation and add query_formatted() method

Also added more tests and documentation.

  • Property svn:keywords set to Id
File size: 43.9 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 799 2016-01-31 18:45:58Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence) if
    # they are passed. The binding syntax is the same as the % operator
    # for dictionaries and tuples; the values are quoted and escaped.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from json import loads as jsondecode, dumps as jsonencode
76
77try:
78    long
79except NameError:  # Python >= 3.0
80    long = int
81
82try:
83    unicode
84except NameError:  # Python >= 3.0
85    unicode = str
86
87try:
88    basestring
89except NameError:  # Python >= 3.0
90    basestring = (str, bytes)
91
try:
    # the ABCs live in collections.abc since Python 3.3; the aliases in
    # collections itself have been removed in Python 3.10
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args):
            """Placeholder that fails when no OrderedDict is available."""
            raise NotSupportedError('OrderedDict is not supported')
101
102
103### Module Constants
104
# the module complies with the DB-API 2.0
apilevel = '2.0'

# threads may share the module, but not connections
threadsafety = 1

# parameters are passed using extended Python format codes like %(name)s
paramstyle = 'pyformat'

# the shortcut methods are not part of the DB API 2 and are not
# recommended by the DB SIG, but they can be handy nevertheless
shortcutmethods = 1
117
118
119### Internal Type Handling
120
def decimal_type(decimal_type=None):
    """Return the global type used for decimal values, optionally setting it.

    If decimal_type is given, it becomes the new global decimal type and
    is also registered as the typecast function for 'numeric' values.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    Decimal = decimal_type
    set_typecast('numeric', decimal_type)
    return Decimal
132
133
def cast_bool(value):
    """Convert a boolean in database format to a Python bool.

    Only the first character is inspected ('t' or 'T' means true).
    An empty value yields None.
    """
    if not value:
        return None
    first = value[0]
    return first == 't' or first == 'T'
138
139
def cast_money(value):
    """Convert a money value in database format to a Decimal.

    All characters except digits, '.' and '-' are dropped; an opening
    parenthesis is taken as a minus sign.  An empty value yields None.
    """
    if not value:
        return None
    kept = [c for c in value.replace('(', '-')
            if c.isdigit() or c in '.-']
    return Decimal(''.join(kept))
145
146
def cast_int2vector(value):
    """Convert an int2vector value (whitespace-separated ints) to a list."""
    return list(map(int, value.split()))
150
151
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    Each cast function receives the string representation of a database
    value and must convert it to a Python object.  Cast functions are
    never called with None, because NULL values have already been
    handled before a cast function is looked up.
    """

    # Cast functions used by default.  The str entries do not convert
    # anything; they are listed only to make lookups for these types fast.
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int,
        'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
        'oid': long, 'oid8': long,
        'float4': float, 'float8': float,
        'numeric': Decimal, 'money': cast_money,
        'int2vector': cast_int2vector,
        'anyarray': cast_array, 'record': cast_record}

    def __missing__(self, typ):
        """Find and cache the cast function for a type not looked up yet.

        This never raises a KeyError; None is returned when there is no
        special cast function for the type.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            self[typ] = cast  # cache the default for faster access
            return cast
        if not typ.startswith('_'):
            return cast
        # an array type: derive its cast from the element type's cast
        element_cast = self[typ[1:]]
        cast = self.create_array_cast(element_cast)
        if element_cast:
            # only cache the array cast if the element type is known
            self[typ] = cast
        return cast

    def get(self, typ, default=None):
        """Return the typecast function for the type, or the default."""
        cast = self[typ]
        return cast if cast else default

    def set(self, typ, cast):
        """Register a typecast function for one or more database types.

        A cast of None removes the special typecast function.  Cached
        array casts for the affected types are discarded in any case.
        """
        names = [typ] if isinstance(typ, basestring) else typ
        if cast is not None and not callable(cast):
            raise TypeError("Cast parameter must be callable")
        for name in names:
            if cast is None:
                self.pop(name, None)
            else:
                self[name] = cast
            self.pop('_%s' % name, None)

    def reset(self, typ=None):
        """Restore the default typecast function(s) for the given type(s).

        Without a type, all typecast functions are reset.
        """
        defaults = self.defaults
        if typ is None:
            self.clear()
            self.update(defaults)
            return
        names = [typ] if isinstance(typ, basestring) else typ
        for name in names:
            array_name = '_%s' % name
            cast = defaults.get(name)
            if not cast:
                self.pop(name, None)
                self.pop(array_name, None)
                continue
            self[name] = cast
            array_cast = defaults.get(array_name)
            if array_cast:
                self[array_name] = array_cast
            else:
                self.pop(array_name, None)

    def create_array_cast(self, cast):
        """Return an array typecast using the given cast for the elements."""
        return lambda value: cast_array(value, cast)

    def create_record_cast(self, name, fields, casts):
        """Return a typecast creating named tuples for the record type."""
        record_type = namedtuple(name, fields)
        return lambda value: record_type(*cast_record(value, casts))


_typecasts = Typecasts()  # the global typecast dictionary
251
252
def get_typecast(typ):
    """Return the global typecast function for the given database type.

    None is returned if there is no special typecast function for it.
    """
    return _typecasts.get(typ, None)
256
257
def set_typecast(typ, cast):
    """Register a global typecast function for the given database type(s).

    The type can be a single type name or a sequence of type names.
    A cast of None removes the special typecast function.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    typecasts = _typecasts
    typecasts.set(typ, cast)
265
266
def reset_typecast(typ=None):
    """Restore the default global typecast function(s) for the type(s).

    Without a type, all global typecast functions are reset.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call con.type_cache.reset_typecast().
    """
    typecasts = _typecasts
    typecasts.reset(typ)
276
277
class LocalTypecasts(Typecasts):
    """Typecast dictionary of a connection, also covering composite types.

    Defaults are taken from the global typecast dictionary.
    """

    defaults = _typecasts

    def __missing__(self, typ):
        """Find and cache the cast function for a type not looked up yet."""
        if typ.startswith('_'):
            # an array type: derive its cast from the element type's cast
            element_cast = self[typ[1:]]
            cast = self.create_array_cast(element_cast)
            if element_cast:
                self[typ] = cast
            return cast
        cast = self.defaults.get(typ)
        if cast:
            self[typ] = cast
            return cast
        # possibly a composite type: build a record cast from its fields
        fields = self.get_fields(typ)
        if fields:
            casts = [self[field.type] for field in fields]
            names = [field.name for field in fields]
            cast = self.create_record_cast(typ, names, casts)
            self[typ] = cast
        return cast

    def get_fields(self, typ):
        """Return the fields of the given record type (none by default).

        The type cache of the connection replaces this method with one
        that looks up the fields of composite types in the database.
        """
        return []
310
311
class TypeCode(str):
    """Type code of a result column as used by the DB-API 2.0.

    A TypeCode compares equal to the PostgreSQL type name it was created
    from, and carries further information on that type as attributes.
    """

    @classmethod
    def create(cls, oid, name, len, type, category, delim, relid):
        """Create a type code for the PostgreSQL data type with the name."""
        code = cls(name)
        code.oid, code.len, code.type = oid, len, type
        code.category, code.delim, code.relid = category, delim, relid
        return code
330
# name and type code (or None if unknown) of a field of a composite type
FieldInfo = namedtuple('FieldInfo', 'name type')
332
333
class TypeCache(dict):
    """Cache for database types of a connection.

    Both type OIDs and type names are used as keys; the values are
    TypeCode strings holding information on the associated type.
    """

    def __init__(self, cnx):
        """Create an empty type cache for the given connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()
        self._typecasts = LocalTypecasts()
        # let the local typecasts look up composite types via this cache
        self._typecasts.get_fields = self.get_fields

    def __missing__(self, key):
        """Query the database for a type that has not been cached yet."""
        if isinstance(key, int):
            oid = key
        else:
            if '.' not in key and '"' not in key:
                # presumably quoted so that regtype takes the name
                # literally (e.g. preserving its case) -- confirm
                key = '"%s"' % key
            oid = "'%s'::regtype" % self._escape_string(key)
        query = ("SELECT oid, typname,"
            " typlen, typtype, typcategory, typdelim, typrelid"
            " FROM pg_type WHERE oid=%s" % oid)
        try:
            self._src.execute(query)
        except ProgrammingError:
            rows = None
        else:
            rows = self._src.fetch(1)
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        row = rows[0]
        type_code = TypeCode.create(int(row[0]), row[1], int(row[2]),
            row[3], row[4], row[5], int(row[6]))
        # cache the type under both its OID and its name
        self[type_code.oid] = self[str(type_code)] = type_code
        return type_code

    def get(self, key, default=None):
        """Return the type for the key, querying the database if needed.

        The default is returned if the type cannot be found.
        """
        try:
            type_code = self[key]
        except KeyError:
            type_code = default
        return type_code

    def get_fields(self, typ):
        """Return the names and types of the fields of a composite type.

        None is returned if the type is unknown or is not composite.
        """
        if not isinstance(typ, TypeCode):
            typ = self.get(typ)
            if not typ:
                return None
        relid = typ.relid
        if not relid:
            return None  # not a composite type
        query = ("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % relid)
        self._src.execute(query)
        return [FieldInfo(attname, self.get(int(atttypid)))
                for attname, atttypid in self._src.fetch(-1)]

    def get_typecast(self, typ):
        """Return the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Register a typecast function for the given database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Restore the default typecast function(s) for the type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Convert the value according to the given database type."""
        if value is None:
            return None  # NULL values need no conversion
        cast = self.get_typecast(typ)
        if cast and cast is not str:
            return cast(value)
        return value  # no conversion necessary
416
417
418class _quotedict(dict):
419    """Dictionary with auto quoting of its items.
420
421    The quote attribute must be set to the desired quote function.
422    """
423
424    def __getitem__(self, key):
425        return self.quote(super(_quotedict, self).__getitem__(key))
426
427
428### Error messages
429
def _db_error(msg, cls=DatabaseError):
    """Create an exception of the given class with an empty sqlstate.

    The exception is returned (not raised); its sqlstate is set to None.
    """
    err = cls(msg)
    err.sqlstate = None
    return err
435
436
def _op_error(msg):
    """Create an OperationalError (with an empty sqlstate) for the message."""
    return _db_error(msg, cls=OperationalError)
440
441
442### Cursor Object
443
444class Cursor(object):
445    """Cursor object."""
446
447    def __init__(self, dbcnx):
448        """Create a cursor object for the database connection."""
449        self.connection = self._dbcnx = dbcnx
450        self._cnx = dbcnx._cnx
451        self.type_cache = dbcnx.type_cache
452        self._src = self._cnx.source()
453        # the official attribute for describing the result columns
454        self._description = None
455        if self.row_factory is Cursor.row_factory:
456            # the row factory needs to be determined dynamically
457            self.row_factory = None
458        else:
459            self.build_row_factory = None
460        self.rowcount = -1
461        self.arraysize = 1
462        self.lastrowid = None
463
464    def __iter__(self):
465        """Make cursor compatible to the iteration protocol."""
466        return self
467
468    def __enter__(self):
469        """Enter the runtime context for the cursor object."""
470        return self
471
472    def __exit__(self, et, ev, tb):
473        """Exit the runtime context for the cursor object."""
474        self.close()
475
    def _quote(self, value):
        """Quote value depending on its type.

        Returns an SQL representation of the value: 'NULL' for None,
        a quoted and escaped literal for strings (and for date/time and
        Json values, which are converted to strings first), quoted special
        values for infinite and NaN floats, ARRAY[...] for lists and
        (...) for tuples.  Finite floats, ints, Decimals and Literals are
        returned unchanged; callers convert the result with str() or %s.
        Other objects must provide a __pg_repr__() method.

        Raises InterfaceError if the value cannot be adapted.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, (datetime, date, time, timedelta, Json)):
            # these are passed to the database in their string form
            value = str(value)
        if isinstance(value, basestring):
            if isinstance(value, Binary):
                value = self._cnx.escape_bytea(value)
                if bytes is not str:  # Python >= 3.0
                    value = value.decode('ascii')
            else:
                # NOTE(review): on Python 3 a plain (non-Binary) bytes value
                # also ends up here; if escape_string() returns bytes for it,
                # the "'%s'" formatting below would embed its b'...' repr
                # -- confirm how bytes should be handled.
                value = self._cnx.escape_string(value)
            return "'%s'" % value
        if isinstance(value, float):
            # special float values must be passed as quoted literals
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal, Literal)):
            return value
        if isinstance(value, list):
            # Quote value as an ARRAY constructor. This is better than using
            # an array literal because it carries the information that this is
            # an array and not a string.  One issue with this syntax is that
            # you need to add an explicit typecast when passing empty arrays.
            # The ARRAY keyword is actually only necessary at the top level.
            q = self._quote
            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
        if isinstance(value, tuple):
            # Quote as a ROW constructor.  This is better than using a record
            # literal because it carries the information that this is a record
            # and not a string.  We don't use the keyword ROW in order to make
            # this usable with the IN syntax as well.  It is only necessary
            # when the record has a single column which is not really useful.
            q = self._quote
            return '(%s)' % ','.join(str(q(v)) for v in value)
        try:
            # objects may provide their own SQL representation
            value = value.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        if isinstance(value, (tuple, list)):
            # a returned tuple or list is quoted as a ROW or ARRAY
            value = self._quote(value)
        return value
522
523    def _quoteparams(self, string, parameters):
524        """Quote parameters.
525
526        This function works for both mappings and sequences.
527        """
528        if isinstance(parameters, dict):
529            parameters = _quotedict(parameters)
530            parameters.quote = self._quote
531        else:
532            parameters = tuple(map(self._quote, parameters))
533        return string % parameters
534
535    def _make_description(self, info):
536        """Make the description tuple for the given field info."""
537        name, typ, size, mod = info[1:]
538        type_code = self.type_cache[typ]
539        if mod > 0:
540            mod -= 4
541        if type_code == 'numeric':
542            precision, scale = mod >> 16, mod & 0xffff
543            size = precision
544        else:
545            if not size:
546                size = type_info.size
547            if size == -1:
548                size = mod
549            precision = scale = None
550        return CursorDescription(name, type_code,
551            None, size, precision, scale, None)
552
553    @property
554    def description(self):
555        """Read-only attribute describing the result columns."""
556        descr = self._description
557        if self._description is True:
558            make = self._make_description
559            descr = [make(info) for info in self._src.listinfo()]
560            self._description = descr
561        return descr
562
563    @property
564    def colnames(self):
565        """Unofficial convenience method for getting the column names."""
566        return [d[0] for d in self.description]
567
568    @property
569    def coltypes(self):
570        """Unofficial convenience method for getting the column types."""
571        return [d[1] for d in self.description]
572
573    def close(self):
574        """Close the cursor object."""
575        self._src.close()
576        self._description = None
577        self.rowcount = -1
578        self.lastrowid = None
579
580    def execute(self, operation, parameters=None):
581        """Prepare and execute a database operation (query or command)."""
582        # The parameters may also be specified as list of tuples to e.g.
583        # insert multiple rows in a single operation, but this kind of
584        # usage is deprecated.  We make several plausibility checks because
585        # tuples can also be passed with the meaning of ROW constructors.
586        if (parameters and isinstance(parameters, list)
587                and len(parameters) > 1
588                and all(isinstance(p, tuple) for p in parameters)
589                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
590            return self.executemany(operation, parameters)
591        else:
592            # not a list of tuples
593            return self.executemany(operation, [parameters])
594
    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        Each item of seq_of_parameters is substituted into the operation
        (unless it is empty) and the resulting statement is executed.
        A BEGIN is sent first unless the connection's _tnx flag shows that
        a transaction has already been started.

        Returns the cursor itself, or None if seq_of_parameters is empty.
        DatabaseErrors are re-raised unchanged; other module Errors are
        raised as InterfaceError and any other exception as OperationalError,
        with a message naming the statement that failed.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self._description = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        # sql always holds the current statement, for the error messages
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception as err:
                    raise _op_error("Can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                if parameters:
                    sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                # NOTE(review): the original comment here said "true if not
                # DML", but truthy results are summed up as the row count,
                # which suggests execute() returns the number of affected
                # rows for DML statements -- confirm against _pg.
                if rows:
                    rowcount += rows
                else:
                    # (rowcount is set again below after all statements ran)
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "Error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("Internal error in '%s': %s" % (sql, err))
        # then initialize result row count and description
        if self._src.resulttype == RESULT_DQL:
            self._description = True  # fetch on demand
            self.rowcount = self._src.ntuples
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self
643
644    def fetchone(self):
645        """Fetch the next row of a query result set."""
646        res = self.fetchmany(1, False)
647        try:
648            return res[0]
649        except IndexError:
650            return None
651
652    def fetchall(self):
653        """Fetch all (remaining) rows of a query result."""
654        return self.fetchmany(-1, False)
655
656    def fetchmany(self, size=None, keep=False):
657        """Fetch the next set of rows of a query result.
658
659        The number of rows to fetch per call is specified by the
660        size parameter. If it is not given, the cursor's arraysize
661        determines the number of rows to be fetched. If you set
662        the keep parameter to true, this is kept as new arraysize.
663        """
664        if size is None:
665            size = self.arraysize
666        if keep:
667            self.arraysize = size
668        try:
669            result = self._src.fetch(size)
670        except DatabaseError:
671            raise
672        except Error as err:
673            raise _db_error(str(err))
674        typecast = self.type_cache.typecast
675        return [self.row_factory([typecast(value, typ)
676            for typ, value in zip(self.coltypes, row)]) for row in result]
677
678    def callproc(self, procname, parameters=None):
679        """Call a stored database procedure with the given name.
680
681        The sequence of parameters must contain one entry for each input
682        argument that the procedure expects. The result of the call is the
683        same as this input sequence; replacement of output and input/output
684        parameters in the return value is currently not supported.
685
686        The procedure may also provide a result set as output. These can be
687        requested through the standard fetch methods of the cursor.
688        """
689        n = parameters and len(parameters) or 0
690        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
691        self.execute(query, parameters)
692        return parameters
693
694    def copy_from(self, stream, table,
695            format=None, sep=None, null=None, size=None, columns=None):
696        """Copy data from an input stream to the specified table.
697
698        The input stream can be a file-like object with a read() method or
699        it can also be an iterable returning a row or multiple rows of input
700        on each iteration.
701
702        The format must be text, csv or binary. The sep option sets the
703        column separator (delimiter) used in the non binary formats.
704        The null option sets the textual representation of NULL in the input.
705
706        The size option sets the size of the buffer used when reading data
707        from file-like objects.
708
709        The copy operation can be restricted to a subset of columns. If no
710        columns are specified, all of them will be copied.
711        """
712        binary_format = format == 'binary'
713        try:
714            read = stream.read
715        except AttributeError:
716            if size:
717                raise ValueError("Size must only be set for file-like objects")
718            if binary_format:
719                input_type = bytes
720                type_name = 'byte strings'
721            else:
722                input_type = basestring
723                type_name = 'strings'
724
725            if isinstance(stream, basestring):
726                if not isinstance(stream, input_type):
727                    raise ValueError("The input must be %s" % type_name)
728                if not binary_format:
729                    if isinstance(stream, str):
730                        if not stream.endswith('\n'):
731                            stream += '\n'
732                    else:
733                        if not stream.endswith(b'\n'):
734                            stream += b'\n'
735
736                def chunks():
737                    yield stream
738
739            elif isinstance(stream, Iterable):
740
741                def chunks():
742                    for chunk in stream:
743                        if not isinstance(chunk, input_type):
744                            raise ValueError(
745                                "Input stream must consist of %s" % type_name)
746                        if isinstance(chunk, str):
747                            if not chunk.endswith('\n'):
748                                chunk += '\n'
749                        else:
750                            if not chunk.endswith(b'\n'):
751                                chunk += b'\n'
752                        yield chunk
753
754            else:
755                raise TypeError("Need an input stream to copy from")
756        else:
757            if size is None:
758                size = 8192
759            elif not isinstance(size, int):
760                raise TypeError("The size option must be an integer")
761            if size > 0:
762
763                def chunks():
764                    while True:
765                        buffer = read(size)
766                        yield buffer
767                        if not buffer or len(buffer) < size:
768                            break
769
770            else:
771
772                def chunks():
773                    yield read()
774
775        if not table or not isinstance(table, basestring):
776            raise TypeError("Need a table to copy to")
777        if table.lower().startswith('select'):
778                raise ValueError("Must specify a table, not a query")
779        else:
780            table = '"%s"' % (table,)
781        operation = ['copy %s' % (table,)]
782        options = []
783        params = []
784        if format is not None:
785            if not isinstance(format, basestring):
786                raise TypeError("The frmat option must be be a string")
787            if format not in ('text', 'csv', 'binary'):
788                raise ValueError("Invalid format")
789            options.append('format %s' % (format,))
790        if sep is not None:
791            if not isinstance(sep, basestring):
792                raise TypeError("The sep option must be a string")
793            if format == 'binary':
794                raise ValueError(
795                    "The sep option is not allowed with binary format")
796            if len(sep) != 1:
797                raise ValueError(
798                    "The sep option must be a single one-byte character")
799            options.append('delimiter %s')
800            params.append(sep)
801        if null is not None:
802            if not isinstance(null, basestring):
803                raise TypeError("The null option must be a string")
804            options.append('null %s')
805            params.append(null)
806        if columns:
807            if not isinstance(columns, basestring):
808                columns = ','.join('"%s"' % (col,) for col in columns)
809            operation.append('(%s)' % (columns,))
810        operation.append("from stdin")
811        if options:
812            operation.append('(%s)' % ','.join(options))
813        operation = ' '.join(operation)
814
815        putdata = self._src.putdata
816        self.execute(operation, params)
817
818        try:
819            for chunk in chunks():
820                putdata(chunk)
821        except BaseException as error:
822            self.rowcount = -1
823            # the following call will re-raise the error
824            putdata(error)
825        else:
826            self.rowcount = putdata(None)
827
828        # return the cursor object, so you can chain operations
829        return self
830
831    def copy_to(self, stream, table,
832            format=None, sep=None, null=None, decode=None, columns=None):
833        """Copy data from the specified table to an output stream.
834
835        The output stream can be a file-like object with a write() method or
836        it can also be None, in which case the method will return a generator
837        yielding a row on each iteration.
838
839        Output will be returned as byte strings unless you set decode to true.
840
841        Note that you can also use a select query instead of the table name.
842
843        The format must be text, csv or binary. The sep option sets the
844        column separator (delimiter) used in the non binary formats.
845        The null option sets the textual representation of NULL in the output.
846
847        The copy operation can be restricted to a subset of columns. If no
848        columns are specified, all of them will be copied.
849        """
850        binary_format = format == 'binary'
851        if stream is not None:
852            try:
853                write = stream.write
854            except AttributeError:
855                raise TypeError("Need an output stream to copy to")
856        if not table or not isinstance(table, basestring):
857            raise TypeError("Need a table to copy to")
858        if table.lower().startswith('select'):
859            if columns:
860                raise ValueError("Columns must be specified in the query")
861            table = '(%s)' % (table,)
862        else:
863            table = '"%s"' % (table,)
864        operation = ['copy %s' % (table,)]
865        options = []
866        params = []
867        if format is not None:
868            if not isinstance(format, basestring):
869                raise TypeError("The format option must be a string")
870            if format not in ('text', 'csv', 'binary'):
871                raise ValueError("Invalid format")
872            options.append('format %s' % (format,))
873        if sep is not None:
874            if not isinstance(sep, basestring):
875                raise TypeError("The sep option must be a string")
876            if binary_format:
877                raise ValueError(
878                    "The sep option is not allowed with binary format")
879            if len(sep) != 1:
880                raise ValueError(
881                    "The sep option must be a single one-byte character")
882            options.append('delimiter %s')
883            params.append(sep)
884        if null is not None:
885            if not isinstance(null, basestring):
886                raise TypeError("The null option must be a string")
887            options.append('null %s')
888            params.append(null)
889        if decode is None:
890            if format == 'binary':
891                decode = False
892            else:
893                decode = str is unicode
894        else:
895            if not isinstance(decode, (int, bool)):
896                raise TypeError("The decode option must be a boolean")
897            if decode and binary_format:
898                raise ValueError(
899                    "The decode option is not allowed with binary format")
900        if columns:
901            if not isinstance(columns, basestring):
902                columns = ','.join('"%s"' % (col,) for col in columns)
903            operation.append('(%s)' % (columns,))
904
905        operation.append("to stdout")
906        if options:
907            operation.append('(%s)' % ','.join(options))
908        operation = ' '.join(operation)
909
910        getdata = self._src.getdata
911        self.execute(operation, params)
912
913        def copy():
914            self.rowcount = 0
915            while True:
916                row = getdata(decode)
917                if isinstance(row, int):
918                    if self.rowcount != row:
919                        self.rowcount = row
920                    break
921                self.rowcount += 1
922                yield row
923
924        if stream is None:
925            # no input stream, return the generator
926            return copy()
927
928        # write the rows to the file-like input stream
929        for row in copy():
930            write(row)
931
932        # return the cursor object, so you can chain operations
933        return self
934
935    def __next__(self):
936        """Return the next row (support for the iteration protocol)."""
937        res = self.fetchone()
938        if res is None:
939            raise StopIteration
940        return res
941
942    # Note that since Python 2.6 the iterator protocol uses __next()__
943    # instead of next(), we keep it only for backward compatibility of pgdb.
944    next = __next__
945
946    @staticmethod
947    def nextset():
948        """Not supported."""
949        raise NotSupportedError("The nextset() method is not supported")
950
951    @staticmethod
952    def setinputsizes(sizes):
953        """Not supported."""
954        pass  # unsupported, but silently passed
955
956    @staticmethod
957    def setoutputsize(size, column=0):
958        """Not supported."""
959        pass  # unsupported, but silently passed
960
961    @staticmethod
962    def row_factory(row):
963        """Process rows before they are returned.
964
965        You can overwrite this statically with a custom row factory, or
966        you can build a row factory dynamically with build_row_factory().
967
968        For example, you can create a Cursor class that returns rows as
969        Python dictionaries like this:
970
971            class DictCursor(pgdb.Cursor):
972
973                def row_factory(self, row):
974                    return {desc[0]: value
975                        for desc, value in zip(self.description, row)}
976
977            cur = DictCursor(con)  # get one DictCursor instance or
978            con.cursor_type = DictCursor  # always use DictCursor instances
979        """
980        raise NotImplementedError
981
982    def build_row_factory(self):
983        """Build a row factory based on the current description.
984
985        This implementation builds a row factory for creating named tuples.
986        You can overwrite this method if you want to dynamically create
987        different row factories whenever the column description changes.
988        """
989        colnames = self.colnames
990        if colnames:
991            try:
992                try:
993                    return namedtuple('Row', colnames, rename=True)._make
994                except TypeError:  # Python 2.6 and 3.0 do not support rename
995                    colnames = [v if v.isalnum() else 'column_%d' % n
996                             for n, v in enumerate(colnames)]
997                    return namedtuple('Row', colnames)._make
998            except ValueError:  # there is still a problem with the field names
999                colnames = ['column_%d' % n for n in range(len(colnames))]
1000                return namedtuple('Row', colnames)._make
1001
1002
# Named tuple with the seven items that DB-API 2 requires for each column
# of a result description.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
1006
1007
1008### Connection Objects
1009
class Connection(object):
    """Connection object.

    Wraps a low-level connection (cnx) and remembers whether a transaction
    is pending. New cursors are instances of the class in cursor_type.
    """

    # The DB-API exception classes are also reachable as attributes of the
    # connection object (an optional DB-API 2 extension).
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        The low-level connection is probed once by calling its source()
        method; any failure there is reported as "Invalid connection".
        """
        self._cnx = cnx  # low-level connection, reset to None by close()
        self._tnx = False  # pending transaction (presumably set by cursors)
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # class instantiated by cursor()
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("Invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The context can be used for running a transaction; the connection
        itself is the context value.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        The connection stays open. The pending transaction is committed if
        the block completed normally and rolled back if it raised; the
        exception itself is not suppressed.
        """
        failed = et is not None or ev is not None or tb is not None
        if failed:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back first; database errors during
        that rollback are ignored so that the connection still gets closed.
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # best effort, closing matters more
        self._cnx.close()
        self._cnx = None

    def _end_transaction(self, command, failure):
        """Finish a pending transaction by executing the SQL command.

        Does nothing if no transaction is pending. Database errors are
        re-raised as they are; any other error is reported using failure.
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        if not self._tnx:
            return
        self._tnx = False
        try:
            self._cnx.source().execute(command)
        except DatabaseError:
            raise
        except Exception:
            raise _op_error(failure)

    def commit(self):
        """Commit any pending transaction to the database."""
        self._end_transaction("COMMIT", "Can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        self._end_transaction("ROLLBACK", "Can't rollback")

    def cursor(self):
        """Return a new cursor object using the connection.

        The cursor is an instance of cursor_type; any failure while
        creating it is reported as "Invalid connection".
        """
        if not self._cnx:
            raise _op_error("Connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("Invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor.

            Returns the new cursor, from which results can be fetched.
            """
            cur = self.cursor()
            cur.execute(operation, params)
            return cur

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence.

            Returns the new cursor that executed the operation.
            """
            cur = self.cursor()
            cur.executemany(operation, param_seq)
            return cur
1117
1118
1119### Module Interface
1120
# Keep the previously bound connect() (presumably the low-level function
# from the C extension module) under another name, because the name
# connect is rebound to the DB-API connect() function below, which
# delegates to it.
_connect_ = connect
1122
def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database.

    The dsn is an optional string 'host:database:user:password:opt' in
    which every part may be omitted. The keyword arguments user, password,
    database and host take precedence over the parts of the dsn. The host
    argument may include a port as 'host:port'; a port that is not an
    integer is ignored. Empty host and user values are passed on as None.

    Returns a Connection object wrapping the new low-level connection.
    """
    # everything the dsn does not provide defaults to an empty string
    dbhost = dbbase = dbuser = dbpasswd = dbopt = ""
    dbport = -1
    try:
        dsn_parts = dsn.split(":")
    except (AttributeError, TypeError):
        pass  # no usable dsn, keep the defaults
    else:
        # pad missing trailing parts with defaults, ignore surplus parts
        dbhost, dbbase, dbuser, dbpasswd, dbopt = (dsn_parts + [""] * 5)[:5]

    # explicitly passed keyword arguments override the dsn
    dbuser = dbuser if user is None else user
    dbpasswd = dbpasswd if password is None else password
    dbbase = dbbase if database is None else database
    if host is not None:
        try:
            dbhost, port = (host.split(":") + [None])[:2]
            dbport = int(port)
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # no (valid) port given, keep the default

    # empty host is localhost
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
1168
1169
1170### Types Handling
1171
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozenset of PostgreSQL type names. It compares equal to
    any type name string it contains. A leading underscore is stripped
    before the comparison, so array type names such as '_int4' also match
    the Type of their base type. Comparisons with anything other than a
    string use the normal frozenset semantics.
    """

    def __new__(cls, values):
        """Create a Type from an iterable of type names.

        A string is taken as a whitespace-separated list of type names.
        """
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        """Check whether other is one of the type names (or an equal set)."""
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]  # array type: compare its base type
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        """Negation of __eq__ (Python 2 does not derive it automatically)."""
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Python 3 sets __hash__ to None in classes that define __eq__, which
    # made Type objects unhashable there (unlike in Python 2). Restore the
    # frozenset hash; it is consistent with equality between Type sets.
    __hash__ = frozenset.__hash__
1199
1200
class ArrayType:
    """Type class for PostgreSQL array types.

    Instances compare equal to every type name string starting with an
    underscore (the prefix of PostgreSQL array type names) and to other
    ArrayType instances.
    """

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return isinstance(other, ArrayType)
        return other.startswith('_')

    def __ne__(self, other):
        if not isinstance(other, basestring):
            return not isinstance(other, ArrayType)
        return not other.startswith('_')
1215
1216
class RecordType:
    """Type class for PostgreSQL record types.

    Instances compare equal to TypeCode objects whose type attribute is
    'c' (composite types), to the type name 'record', and to other
    RecordType instances.
    """

    def __eq__(self, other):
        if isinstance(other, TypeCode):
            return other.type == 'c'
        if isinstance(other, basestring):
            return other == 'record'
        return isinstance(other, RecordType)

    def __ne__(self, other):
        if isinstance(other, TypeCode):
            return other.type != 'c'
        if isinstance(other, basestring):
            return other != 'record'
        return not isinstance(other, RecordType)
1235
1236
# Type objects required by the DB-API 2 specification:

STRING = Type(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = Type(['bytea'])
NUMBER = Type(['int2', 'int4', 'serial', 'int8',
               'float4', 'float8', 'numeric', 'money'])
DATETIME = Type(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
                 'interval',
                 'abstime', 'reltime'])  # obsolete, removed in PostgreSQL 12
ROWID = Type(['oid'])
1245
1246
# More specific type objects provided in addition to the mandatory ones:

BOOL = Type(['bool'])
SMALLINT = Type(['int2'])
INTEGER = Type(['int2', 'int4', 'int8', 'serial'])
LONG = Type(['int8'])
FLOAT = Type(['float4', 'float8'])
NUMERIC = Type(['numeric'])
MONEY = Type(['money'])
DATE = Type(['date'])
TIME = Type(['time', 'timetz'])
TIMESTAMP = Type(['timestamp', 'timestamptz'])
INTERVAL = Type(['interval'])
JSON = Type(['json', 'jsonb'])
1261
# Type object matching every array type (note that array type names
# also equate to the type objects of their base types):

ARRAY = ArrayType()

# Type object matching records, i.e. all composite types:

RECORD = RecordType()
1269
1270
1271# Mandatory type helpers defined by DB-API 2 specs:
1272
def Date(year, month, day):
    """Construct an object holding a date value (a datetime.date)."""
    return date(year=year, month=month, day=day)
1276
1277
def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Construct an object holding a time value.

    The optional tzinfo argument (a datetime.tzinfo instance) creates a
    time zone aware time; without it a naive time is returned as before.
    """
    return time(hour, minute, second, microsecond, tzinfo)
1281
1282
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Construct an object holding a time stamp value.

    The optional tzinfo argument (a datetime.tzinfo instance) creates a
    time zone aware datetime; without it a naive datetime is returned.
    """
    return datetime(year, month, day, hour, minute, second, microsecond,
                    tzinfo)
1286
1287
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks (seconds since the epoch) are converted using local time.
    """
    local = localtime(ticks)
    return Date(local.tm_year, local.tm_mon, local.tm_mday)
1291
1292
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    The ticks are converted using local time; fractions of a second are
    dropped, because localtime() only provides whole seconds.
    """
    local = localtime(ticks)
    return Time(local.tm_hour, local.tm_min, local.tm_sec)
1296
1297
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    The ticks are converted using local time, with whole-second precision.
    """
    local = localtime(ticks)
    return Timestamp(local.tm_year, local.tm_mon, local.tm_mday,
                     local.tm_hour, local.tm_min, local.tm_sec)
1301
1302
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    The class behaves exactly like bytes.
    """
    pass
1305
1306
1307# Additional type helpers for PyGreSQL:
1308
class Bytea(bytes):
    """Construct an object capable of holding a bytea value.

    The class behaves exactly like bytes.
    """
    pass
1311
1312
class Json:
    """Construct a wrapper for holding an object serializable to JSON.

    The encode function turns the object into its JSON text; it defaults
    to jsonencode when not given (or falsy). Strings are passed through
    unchanged (presumably assumed to contain JSON text already).
    """

    def __init__(self, obj, encode=None):
        self.obj = obj
        if not encode:
            encode = jsonencode
        self.encode = encode

    def __str__(self):
        value = self.obj
        return value if isinstance(value, basestring) else self.encode(value)

    # presumably the hook the adaptation code uses for the SQL representation
    __pg_repr__ = __str__
1327
1328
class Literal:
    """Construct a wrapper for holding a literal SQL string.

    Both str() and __pg_repr__() return the wrapped SQL unchanged.
    """

    def __init__(self, sql):
        self.sql = sql  # the SQL text, kept exactly as given

    def __str__(self):
        return self.sql

    # presumably the hook the adaptation code uses for the SQL representation
    __pg_repr__ = __str__
1339
1340
# If run as script, print some information:

if __name__ == '__main__':
    # show the module version (the name version is defined elsewhere in
    # this module), an empty line, and the module docstring with usage notes
    # NOTE(review): on Python 2 without "from __future__ import
    # print_function" the first call would print a tuple; the imports are
    # outside this view, so confirm.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.