source: trunk/pgdb.py @ 812

Last change on this file since 812 was 812, checked in by cito, 3 years ago

Fix two minor issues with pgdb

Removed a now unnecessary import and fixed one forgotten name change.

  • Property svn:keywords set to Id
File size: 43.6 KB
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 812 2016-02-01 11:21:38Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass these values as the
24    # keyword arguments host, database, user and password. To pass a port,
25    # append it to the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary) if they are
32    # passed. The binding syntax is the same as the % operator
33    # for dictionaries; parameter values are quoted automatically.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size and null_ok are not implemented;
52    # precision and scale are only provided for numeric types.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from json import loads as jsondecode, dumps as jsonencode
76
77try:
78    long
79except NameError:  # Python >= 3.0
80    long = int
81
82try:
83    unicode
84except NameError:  # Python >= 3.0
85    unicode = str
86
87try:
88    basestring
89except NameError:  # Python >= 3.0
90    basestring = (str, bytes)
91
92from collections import Iterable
93
94
95### Module Constants
96
97# compliant with DB API 2.0
98apilevel = '2.0'
99
100# module may be shared, but not connections
101threadsafety = 1
102
103# this module uses extended Python format codes
104paramstyle = 'pyformat'
105
106# shortcut methods have been excluded from DB API 2 and
107# are not recommended by the DB SIG, but they can be handy
108shortcutmethods = 1
109
110
111### Internal Type Handling
112
113def decimal_type(decimal_type=None):
114    """Get or set global type to be used for decimal values.
115
116    Note that connections cache cast functions. To be sure a global change
117    is picked up by a running connection, call con.type_cache.reset_typecast().
118    """
119    global Decimal
120    if decimal_type is not None:
121        Decimal = decimal_type
122        set_typecast('numeric', decimal_type)
123    return Decimal
124
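# For example, to work with plain floats instead of Decimal values
# (a sketch only; this trades exactness for speed):
#
#     decimal_type(float)  # 'numeric' columns will now be cast to float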
125
126def cast_bool(value):
127    """Cast boolean value in database format to bool."""
128    if value:
129        return value[0] in ('t', 'T')
130
131
132def cast_money(value):
133    """Cast money value in database format to Decimal."""
134    if value:
135        value = value.replace('(', '-')
136        return Decimal(''.join(c for c in value if c.isdigit() or c in '.-'))
137
138
139def cast_int2vector(value):
140    """Cast an int2vector value."""
141    return [int(v) for v in value.split()]
142
143
144class Typecasts(dict):
145    """Dictionary mapping database types to typecast functions.
146
147    The cast functions get passed the string representation of a value in
148    the database which they need to convert to a Python object.  The
149    passed string will never be None since NULL values are already
150    handled before the cast function is called.
151    """
152
153    # the default cast functions
154    # (str functions are ignored but have been added for faster access)
155    defaults = {'char': str, 'bpchar': str, 'name': str,
156        'text': str, 'varchar': str,
157        'bool': cast_bool, 'bytea': unescape_bytea,
158        'int2': int, 'int4': int, 'serial': int,
159        'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
160        'oid': long, 'oid8': long,
161        'float4': float, 'float8': float,
162        'numeric': Decimal, 'money': cast_money,
163        'int2vector': cast_int2vector,
164        'anyarray': cast_array, 'record': cast_record}
165
166    def __missing__(self, typ):
167        """Create a cast function if it is not cached.
168
169        Note that this class never raises a KeyError,
170        but returns None when no special cast function exists.
171        """
172        if not isinstance(typ, str):
173            raise TypeError('Invalid type: %s' % typ)
174        cast = self.defaults.get(typ)
175        if cast:
176            # store default for faster access
177            self[typ] = cast
178        elif typ.startswith('_'):
179            # create array cast
180            base_cast = self[typ[1:]]
181            cast = self.create_array_cast(base_cast)
182            if base_cast:
183                # store only if base type exists
184                self[typ] = cast
185        return cast
186
187    def get(self, typ, default=None):
188        """Get the typecast function for the given database type."""
189        return self[typ] or default
190
191    def set(self, typ, cast):
192        """Set a typecast function for the specified database type(s)."""
193        if isinstance(typ, basestring):
194            typ = [typ]
195        if cast is None:
196            for t in typ:
197                self.pop(t, None)
198                self.pop('_%s' % t, None)
199        else:
200            if not callable(cast):
201                raise TypeError("Cast parameter must be callable")
202            for t in typ:
203                self[t] = cast
204                self.pop('_%s' % t, None)
205
206    def reset(self, typ=None):
207        """Reset the typecasts for the specified type(s) to their defaults.
208
209        When no type is specified, all typecasts will be reset.
210        """
211        defaults = self.defaults
212        if typ is None:
213            self.clear()
214            self.update(defaults)
215        else:
216            if isinstance(typ, basestring):
217                typ = [typ]
218            for t in typ:
219                cast = defaults.get(t)
220                if cast:
221                    self[t] = cast
222                    t = '_%s' % t
223                    cast = defaults.get(t)
224                    if cast:
225                        self[t] = cast
226                    else:
227                        self.pop(t, None)
228                else:
229                    self.pop(t, None)
230                    self.pop('_%s' % t, None)
231
232    def create_array_cast(self, cast):
233        """Create an array typecast for the given base cast."""
234        return lambda v: cast_array(v, cast)
235
236    def create_record_cast(self, name, fields, casts):
237        """Create a named record typecast for the given fields and casts."""
238        record = namedtuple(name, fields)
239        return lambda v: record(*cast_record(v, casts))
240
241
242_typecasts = Typecasts()  # this is the global typecast dictionary
243
244
245def get_typecast(typ):
246    """Get the global typecast function for the given database type(s)."""
247    return _typecasts.get(typ)
248
249
250def set_typecast(typ, cast):
251    """Set a global typecast function for the given database type(s).
252
253    Note that connections cache cast functions. To be sure a global change
254    is picked up by a running connection, call con.type_cache.reset_typecast().
255    """
256    _typecasts.set(typ, cast)
257
258
259def reset_typecast(typ=None):
260    """Reset the global typecasts for the given type(s) to their default.
261
262    When no type is specified, all typecasts will be reset.
263
264    Note that connections cache cast functions. To be sure a global change
265    is picked up by a running connection, call con.type_cache.reset_typecast().
266    """
267    _typecasts.reset(typ)
268
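# A sketch of a custom global typecast, assuming values of the built-in
# 'point' type should come back as tuples of floats:
#
#     set_typecast('point',
#         lambda v: tuple(float(x) for x in v.strip('()').split(',')))
#     get_typecast('point')('(1.5,2.5)')  # -> (1.5, 2.5)
#     reset_typecast('point')  # restore the default behavior
#
# Remember that open connections cache their casts; call
# con.type_cache.reset_typecast() so that they pick up the change.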
269
270class LocalTypecasts(Typecasts):
271    """Map typecasts, including local composite types, to cast functions."""
272
273    defaults = _typecasts
274
275    def __missing__(self, typ):
276        """Create a cast function if it is not cached."""
277        if typ.startswith('_'):
278            base_cast = self[typ[1:]]
279            cast = self.create_array_cast(base_cast)
280            if base_cast:
281                self[typ] = cast
282        else:
283            cast = self.defaults.get(typ)
284            if cast:
285                self[typ] = cast
286            else:
287                fields = self.get_fields(typ)
288                if fields:
289                    casts = [self[field.type] for field in fields]
290                    fields = [field.name for field in fields]
291                    cast = self.create_record_cast(typ, fields, casts)
292                    self[typ] = cast
293        return cast
294
295    def get_fields(self, typ):
296        """Return the fields for the given record type.
297
298        This method will be replaced with a method that looks up the fields
299        using the type cache of the connection.
300        """
301        return []
302
303
304class TypeCode(str):
305    """Class representing the type_code used by the DB-API 2.0.
306
307    TypeCode objects are strings equal to the PostgreSQL type name,
308    but carry some additional information.
309    """
310
311    @classmethod
312    def create(cls, oid, name, len, type, category, delim, relid):
313        """Create a type code for a PostgreSQL data type."""
314        self = cls(name)
315        self.oid = oid
316        self.len = len
317        self.type = type
318        self.category = category
319        self.delim = delim
320        self.relid = relid
321        return self
322
323FieldInfo = namedtuple('FieldInfo', ['name', 'type'])
324
325
326class TypeCache(dict):
327    """Cache for database types.
328
329    This cache maps type OIDs and names to TypeCode strings containing
330    important information on the associated database type.
331    """
332
333    def __init__(self, cnx):
334        """Initialize type cache for connection."""
335        super(TypeCache, self).__init__()
336        self._escape_string = cnx.escape_string
337        self._src = cnx.source()
338        self._typecasts = LocalTypecasts()
339        self._typecasts.get_fields = self.get_fields
340
341    def __missing__(self, key):
342        """Get the type info from the database if it is not cached."""
343        if isinstance(key, int):
344            oid = key
345        else:
346            if '.' not in key and '"' not in key:
347                key = '"%s"' % key
348            oid = "'%s'::regtype" % self._escape_string(key)
349        try:
350            self._src.execute("SELECT oid, typname,"
351                 " typlen, typtype, typcategory, typdelim, typrelid"
352                " FROM pg_type WHERE oid=%s" % oid)
353        except ProgrammingError:
354            res = None
355        else:
356            res = self._src.fetch(1)
357        if not res:
358            raise KeyError('Type %s could not be found' % key)
359        res = res[0]
360        type_code = TypeCode.create(int(res[0]), res[1],
361            int(res[2]), res[3], res[4], res[5], int(res[6]))
362        self[type_code.oid] = self[str(type_code)] = type_code
363        return type_code
364
365    def get(self, key, default=None):
366        """Get the type even if it is not cached."""
367        try:
368            return self[key]
369        except KeyError:
370            return default
371
372    def get_fields(self, typ):
373        """Get the names and types of the fields of composite types."""
374        if not isinstance(typ, TypeCode):
375            typ = self.get(typ)
376            if not typ:
377                return None
378        if not typ.relid:
379            return None  # this type is not composite
380        self._src.execute("SELECT attname, atttypid"
381            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
382            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
383        return [FieldInfo(name, self.get(int(oid)))
384            for name, oid in self._src.fetch(-1)]
385
386    def get_typecast(self, typ):
387        """Get the typecast function for the given database type."""
388        return self._typecasts.get(typ)
389
390    def set_typecast(self, typ, cast):
391        """Set a typecast function for the specified database type(s)."""
392        self._typecasts.set(typ, cast)
393
394    def reset_typecast(self, typ=None):
395        """Reset the typecast function for the specified database type(s)."""
396        self._typecasts.reset(typ)
397
398    def typecast(self, value, typ):
399        """Cast the given value according to the given database type."""
400        if value is None:
401            # for NULL values, no typecast is necessary
402            return None
403        cast = self.get_typecast(typ)
404        if not cast or cast is str:
405            # no typecast is necessary
406            return value
407        return cast(value)
408
409
410class _quotedict(dict):
411    """Dictionary with auto quoting of its items.
412
413    The quote attribute must be set to the desired quote function.
414    """
415
416    def __getitem__(self, key):
417        return self.quote(super(_quotedict, self).__getitem__(key))
418
419
420### Error messages
421
422def _db_error(msg, cls=DatabaseError):
423    """Return DatabaseError with empty sqlstate attribute."""
424    error = cls(msg)
425    error.sqlstate = None
426    return error
427
428
429def _op_error(msg):
430    """Return OperationalError."""
431    return _db_error(msg, OperationalError)
432
433
434### Cursor Object
435
436class Cursor(object):
437    """Cursor object."""
438
439    def __init__(self, dbcnx):
440        """Create a cursor object for the database connection."""
441        self.connection = self._dbcnx = dbcnx
442        self._cnx = dbcnx._cnx
443        self.type_cache = dbcnx.type_cache
444        self._src = self._cnx.source()
445        # the official attribute for describing the result columns
446        self._description = None
447        if self.row_factory is Cursor.row_factory:
448            # the row factory needs to be determined dynamically
449            self.row_factory = None
450        else:
451            self.build_row_factory = None
452        self.rowcount = -1
453        self.arraysize = 1
454        self.lastrowid = None
455
456    def __iter__(self):
457        """Make cursor compatible to the iteration protocol."""
458        return self
459
460    def __enter__(self):
461        """Enter the runtime context for the cursor object."""
462        return self
463
464    def __exit__(self, et, ev, tb):
465        """Exit the runtime context for the cursor object."""
466        self.close()
467
468    def _quote(self, value):
469        """Quote value depending on its type."""
470        if value is None:
471            return 'NULL'
472        if isinstance(value, (datetime, date, time, timedelta, Json)):
473            value = str(value)
474        if isinstance(value, basestring):
475            if isinstance(value, Binary):
476                value = self._cnx.escape_bytea(value)
477                if bytes is not str:  # Python >= 3.0
478                    value = value.decode('ascii')
479            else:
480                value = self._cnx.escape_string(value)
481            return "'%s'" % value
482        if isinstance(value, float):
483            if isinf(value):
484                return "'-Infinity'" if value < 0 else "'Infinity'"
485            if isnan(value):
486                return "'NaN'"
487            return value
488        if isinstance(value, (int, long, Decimal, Literal)):
489            return value
490        if isinstance(value, list):
491            # Quote value as an ARRAY constructor. This is better than using
492            # an array literal because it carries the information that this is
493            # an array and not a string.  One issue with this syntax is that
494            # you need to add an explicit typecast when passing empty arrays.
495            # The ARRAY keyword is actually only necessary at the top level.
496            q = self._quote
497            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in value)
498        if isinstance(value, tuple):
499            # Quote as a ROW constructor.  This is better than using a record
500            # literal because it carries the information that this is a record
501            # and not a string.  We don't use the keyword ROW in order to make
502            # this usable with the IN syntax as well.  It is only necessary
503            # when the record has a single column, which is not really useful.
504            q = self._quote
505            return '(%s)' % ','.join(str(q(v)) for v in value)
506        try:
507            value = value.__pg_repr__()
508        except AttributeError:
509            raise InterfaceError(
510                'Do not know how to adapt type %s' % type(value))
511        if isinstance(value, (tuple, list)):
512            value = self._quote(value)
513        return value
514
515    def _quoteparams(self, string, parameters):
516        """Quote parameters.
517
518        This function works for both mappings and sequences.
519        """
520        if isinstance(parameters, dict):
521            parameters = _quotedict(parameters)
522            parameters.quote = self._quote
523        else:
524            parameters = tuple(map(self._quote, parameters))
525        return string % parameters
526
527    def _make_description(self, info):
528        """Make the description tuple for the given field info."""
529        name, typ, size, mod = info[1:]
530        type_code = self.type_cache[typ]
531        if mod > 0:
532            mod -= 4
533        if type_code == 'numeric':
534            precision, scale = mod >> 16, mod & 0xffff
535            size = precision
536        else:
537            if not size:
538                size = type_code.len
539            if size == -1:
540                size = mod
541            precision = scale = None
542        return CursorDescription(name, type_code,
543            None, size, precision, scale, None)
544
545    @property
546    def description(self):
547        """Read-only attribute describing the result columns."""
548        descr = self._description
549        if self._description is True:
550            make = self._make_description
551            descr = [make(info) for info in self._src.listinfo()]
552            self._description = descr
553        return descr
554
555    @property
556    def colnames(self):
557        """Unofficial convenience method for getting the column names."""
558        return [d[0] for d in self.description]
559
560    @property
561    def coltypes(self):
562        """Unofficial convenience method for getting the column types."""
563        return [d[1] for d in self.description]
564
565    def close(self):
566        """Close the cursor object."""
567        self._src.close()
568        self._description = None
569        self.rowcount = -1
570        self.lastrowid = None
571
572    def execute(self, operation, parameters=None):
573        """Prepare and execute a database operation (query or command)."""
574        # The parameters may also be specified as list of tuples to e.g.
575        # insert multiple rows in a single operation, but this kind of
576        # usage is deprecated.  We make several plausibility checks because
577        # tuples can also be passed with the meaning of ROW constructors.
578        if (parameters and isinstance(parameters, list)
579                and len(parameters) > 1
580                and all(isinstance(p, tuple) for p in parameters)
581                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
582            return self.executemany(operation, parameters)
583        else:
584            # not a list of tuples
585            return self.executemany(operation, [parameters])
586
587    def executemany(self, operation, seq_of_parameters):
588        """Prepare operation and execute it against a parameter sequence."""
589        if not seq_of_parameters:
590            # don't do anything without parameters
591            return
592        self._description = None
593        self.rowcount = -1
594        # first try to execute all queries
595        rowcount = 0
596        sql = "BEGIN"
597        try:
598            if not self._dbcnx._tnx:
599                try:
600                    self._cnx.source().execute(sql)
601                except DatabaseError:
602                    raise  # database provides error message
603                except Exception as err:
604                    raise _op_error("Can't start transaction")
605                self._dbcnx._tnx = True
606            for parameters in seq_of_parameters:
607                sql = operation
608                if parameters:
609                    sql = self._quoteparams(sql, parameters)
610                rows = self._src.execute(sql)
611                if rows:  # true if not DML
612                    rowcount += rows
613                else:
614                    self.rowcount = -1
615        except DatabaseError:
616            raise  # database provides error message
617        except Error as err:
618            raise _db_error(
619                "Error in '%s': '%s' " % (sql, err), InterfaceError)
620        except Exception as err:
621            raise _op_error("Internal error in '%s': %s" % (sql, err))
622        # then initialize result raw count and description
623        if self._src.resulttype == RESULT_DQL:
624            self._description = True  # fetch on demand
625            self.rowcount = self._src.ntuples
626            self.lastrowid = None
627            if self.build_row_factory:
628                self.row_factory = self.build_row_factory()
629        else:
630            self.rowcount = rowcount
631            self.lastrowid = self._src.oidstatus()
632        # return the cursor object, so you can write statements such as
633        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
634        return self
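    # A usage sketch for executemany(), assuming a connection con with a
    # cursor cur and a table "book" with columns "title" and "year":
    #
    #     cur.executemany(
    #         "insert into book (title, year) values (%(title)s, %(year)s)",
    #         [{'title': 'A', 'year': 2001}, {'title': 'B', 'year': 2002}])
    #     con.commit()  # the insert started a transaction that must be ended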
635
636    def fetchone(self):
637        """Fetch the next row of a query result set."""
638        res = self.fetchmany(1, False)
639        try:
640            return res[0]
641        except IndexError:
642            return None
643
644    def fetchall(self):
645        """Fetch all (remaining) rows of a query result."""
646        return self.fetchmany(-1, False)
647
648    def fetchmany(self, size=None, keep=False):
649        """Fetch the next set of rows of a query result.
650
651        The number of rows to fetch per call is specified by the
652        size parameter. If it is not given, the cursor's arraysize
653        determines the number of rows to be fetched. If you set
654        the keep parameter to true, this is kept as new arraysize.
655        """
656        if size is None:
657            size = self.arraysize
658        if keep:
659            self.arraysize = size
660        try:
661            result = self._src.fetch(size)
662        except DatabaseError:
663            raise
664        except Error as err:
665            raise _db_error(str(err))
666        typecast = self.type_cache.typecast
667        return [self.row_factory([typecast(value, typ)
668            for typ, value in zip(self.coltypes, row)]) for row in result]
669
670    def callproc(self, procname, parameters=None):
671        """Call a stored database procedure with the given name.
672
673        The sequence of parameters must contain one entry for each input
674        argument that the procedure expects. The result of the call is the
675        same as this input sequence; replacement of output and input/output
676        parameters in the return value is currently not supported.
677
678        The procedure may also provide a result set as output. These can be
679        requested through the standard fetch methods of the cursor.
680        """
681        n = parameters and len(parameters) or 0
682        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
683        self.execute(query, parameters)
684        return parameters
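    # A sketch of callproc(), using the built-in lower() function so that
    # no custom procedure needs to be assumed:
    #
    #     cur.callproc('lower', ['WORLD'])
    #     cur.fetchone()[0]  # -> 'world'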
685
686    def copy_from(self, stream, table,
687            format=None, sep=None, null=None, size=None, columns=None):
688        """Copy data from an input stream to the specified table.
689
690        The input stream can be a file-like object with a read() method or
691        it can also be an iterable returning a row or multiple rows of input
692        on each iteration.
693
694        The format must be text, csv or binary. The sep option sets the
695        column separator (delimiter) used in the non binary formats.
696        The null option sets the textual representation of NULL in the input.
697
698        The size option sets the size of the buffer used when reading data
699        from file-like objects.
700
701        The copy operation can be restricted to a subset of columns. If no
702        columns are specified, all of them will be copied.
703        """
704        binary_format = format == 'binary'
705        try:
706            read = stream.read
707        except AttributeError:
708            if size:
709                raise ValueError("Size must only be set for file-like objects")
710            if binary_format:
711                input_type = bytes
712                type_name = 'byte strings'
713            else:
714                input_type = basestring
715                type_name = 'strings'
716
717            if isinstance(stream, basestring):
718                if not isinstance(stream, input_type):
719                    raise ValueError("The input must be %s" % type_name)
720                if not binary_format:
721                    if isinstance(stream, str):
722                        if not stream.endswith('\n'):
723                            stream += '\n'
724                    else:
725                        if not stream.endswith(b'\n'):
726                            stream += b'\n'
727
728                def chunks():
729                    yield stream
730
731            elif isinstance(stream, Iterable):
732
733                def chunks():
734                    for chunk in stream:
735                        if not isinstance(chunk, input_type):
736                            raise ValueError(
737                                "Input stream must consist of %s" % type_name)
738                        if isinstance(chunk, str):
739                            if not chunk.endswith('\n'):
740                                chunk += '\n'
741                        else:
742                            if not chunk.endswith(b'\n'):
743                                chunk += b'\n'
744                        yield chunk
745
746            else:
747                raise TypeError("Need an input stream to copy from")
748        else:
749            if size is None:
750                size = 8192
751            elif not isinstance(size, int):
752                raise TypeError("The size option must be an integer")
753            if size > 0:
754
755                def chunks():
756                    while True:
757                        buffer = read(size)
758                        yield buffer
759                        if not buffer or len(buffer) < size:
760                            break
761
762            else:
763
764                def chunks():
765                    yield read()
766
767        if not table or not isinstance(table, basestring):
768            raise TypeError("Need a table to copy to")
769        if table.lower().startswith('select'):
770            raise ValueError("Must specify a table, not a query")
771        else:
772            table = '"%s"' % (table,)
773        operation = ['copy %s' % (table,)]
774        options = []
775        params = []
776        if format is not None:
777            if not isinstance(format, basestring):
778                raise TypeError("The format option must be a string")
779            if format not in ('text', 'csv', 'binary'):
780                raise ValueError("Invalid format")
781            options.append('format %s' % (format,))
782        if sep is not None:
783            if not isinstance(sep, basestring):
784                raise TypeError("The sep option must be a string")
785            if format == 'binary':
786                raise ValueError(
787                    "The sep option is not allowed with binary format")
788            if len(sep) != 1:
789                raise ValueError(
790                    "The sep option must be a single one-byte character")
791            options.append('delimiter %s')
792            params.append(sep)
793        if null is not None:
794            if not isinstance(null, basestring):
795                raise TypeError("The null option must be a string")
796            options.append('null %s')
797            params.append(null)
798        if columns:
799            if not isinstance(columns, basestring):
800                columns = ','.join('"%s"' % (col,) for col in columns)
801            operation.append('(%s)' % (columns,))
802        operation.append("from stdin")
803        if options:
804            operation.append('(%s)' % ','.join(options))
805        operation = ' '.join(operation)
806
807        putdata = self._src.putdata
808        self.execute(operation, params)
809
810        try:
811            for chunk in chunks():
812                putdata(chunk)
813        except BaseException as error:
814            self.rowcount = -1
815            # the following call will re-raise the error
816            putdata(error)
817        else:
818            self.rowcount = putdata(None)
819
820        # return the cursor object, so you can chain operations
821        return self
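    # A usage sketch for copy_from(), assuming a cursor cur and a table
    # "fruit" with two columns; the input is a plain list of CSV lines
    # (missing newlines are appended automatically):
    #
    #     cur.copy_from(['1,apple', '2,banana'], 'fruit', format='csv')
    #     cur.rowcount  # -> 2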
822
823    def copy_to(self, stream, table,
824            format=None, sep=None, null=None, decode=None, columns=None):
825        """Copy data from the specified table to an output stream.
826
827        The output stream can be a file-like object with a write() method or
828        it can also be None, in which case the method will return a generator
829        yielding a row on each iteration.
830
831        Output will be returned as byte strings unless you set decode to true.
832
833        Note that you can also use a select query instead of the table name.
834
835        The format must be text, csv or binary. The sep option sets the
836        column separator (delimiter) used in the non binary formats.
837        The null option sets the textual representation of NULL in the output.
838
839        The copy operation can be restricted to a subset of columns. If no
840        columns are specified, all of them will be copied.
841        """
842        binary_format = format == 'binary'
843        if stream is not None:
844            try:
845                write = stream.write
846            except AttributeError:
847                raise TypeError("Need an output stream to copy to")
848        if not table or not isinstance(table, basestring):
849            raise TypeError("Need a table to copy from")
850        if table.lower().startswith('select'):
851            if columns:
852                raise ValueError("Columns must be specified in the query")
853            table = '(%s)' % (table,)
854        else:
855            table = '"%s"' % (table,)
856        operation = ['copy %s' % (table,)]
857        options = []
858        params = []
859        if format is not None:
860            if not isinstance(format, basestring):
861                raise TypeError("The format option must be a string")
862            if format not in ('text', 'csv', 'binary'):
863                raise ValueError("Invalid format")
864            options.append('format %s' % (format,))
865        if sep is not None:
866            if not isinstance(sep, basestring):
867                raise TypeError("The sep option must be a string")
868            if binary_format:
869                raise ValueError(
870                    "The sep option is not allowed with binary format")
871            if len(sep) != 1:
872                raise ValueError(
873                    "The sep option must be a single one-byte character")
874            options.append('delimiter %s')
875            params.append(sep)
876        if null is not None:
877            if not isinstance(null, basestring):
878                raise TypeError("The null option must be a string")
879            options.append('null %s')
880            params.append(null)
881        if decode is None:
882            if format == 'binary':
883                decode = False
884            else:
885                decode = str is unicode
886        else:
887            if not isinstance(decode, (int, bool)):
888                raise TypeError("The decode option must be a boolean")
889            if decode and binary_format:
890                raise ValueError(
891                    "The decode option is not allowed with binary format")
892        if columns:
893            if not isinstance(columns, basestring):
894                columns = ','.join('"%s"' % (col,) for col in columns)
895            operation.append('(%s)' % (columns,))
896
897        operation.append("to stdout")
898        if options:
899            operation.append('(%s)' % ','.join(options))
900        operation = ' '.join(operation)
901
902        getdata = self._src.getdata
903        self.execute(operation, params)
904
905        def copy():
906            self.rowcount = 0
907            while True:
908                row = getdata(decode)
909                if isinstance(row, int):
910                    if self.rowcount != row:
911                        self.rowcount = row
912                    break
913                self.rowcount += 1
914                yield row
915
916        if stream is None:
917            # no output stream given, so return the generator
918            return copy()
919
920        # write the rows to the file-like output stream
921        for row in copy():
922            write(row)
923
924        # return the cursor object, so you can chain operations
925        return self
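    # A usage sketch for copy_to(), assuming a cursor cur and a table
    # "fruit"; passing None as the stream yields the rows lazily from a
    # generator, here decoded to str:
    #
    #     for line in cur.copy_to(None, 'fruit', format='csv', decode=True):
    #         print(line, end='')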
926
927    def __next__(self):
928        """Return the next row (support for the iteration protocol)."""
929        res = self.fetchone()
930        if res is None:
931            raise StopIteration
932        return res
933
934    # Note that since Python 2.6 the iterator protocol uses __next__()
935    # instead of next(); we keep next() only for backward compatibility of pgdb.
936    next = __next__
937
938    @staticmethod
939    def nextset():
940        """Not supported."""
941        raise NotSupportedError("The nextset() method is not supported")
942
943    @staticmethod
944    def setinputsizes(sizes):
945        """Not supported."""
946        pass  # unsupported, but silently passed
947
948    @staticmethod
949    def setoutputsize(size, column=0):
950        """Not supported."""
951        pass  # unsupported, but silently passed
952
953    @staticmethod
954    def row_factory(row):
955        """Process rows before they are returned.
956
957        You can overwrite this statically with a custom row factory, or
958        you can build a row factory dynamically with build_row_factory().
959
960        For example, you can create a Cursor class that returns rows as
961        Python dictionaries like this:
962
963            class DictCursor(pgdb.Cursor):
964
965                def row_factory(self, row):
966                    return {desc[0]: value
967                        for desc, value in zip(self.description, row)}
968
969            cur = DictCursor(con)  # get one DictCursor instance or
970            con.cursor_type = DictCursor  # always use DictCursor instances
971        """
972        raise NotImplementedError
973
974    def build_row_factory(self):
975        """Build a row factory based on the current description.
976
977        This implementation builds a row factory for creating named tuples.
978        You can overwrite this method if you want to dynamically create
979        different row factories whenever the column description changes.
980        """
981        colnames = self.colnames
982        if colnames:
983            try:
984                try:
985                    return namedtuple('Row', colnames, rename=True)._make
986                except TypeError:  # Python 2.6 and 3.0 do not support rename
987                    colnames = [v if v.isalnum() else 'column_%d' % n
988                             for n, v in enumerate(colnames)]
989                    return namedtuple('Row', colnames)._make
990            except ValueError:  # there is still a problem with the field names
991                colnames = ['column_%d' % n for n in range(len(colnames))]
992                return namedtuple('Row', colnames)._make
993
994
995CursorDescription = namedtuple('CursorDescription',
996    ['name', 'type_code', 'display_size', 'internal_size',
997     'precision', 'scale', 'null_ok'])
998
999
1000### Connection Objects
1001
1002class Connection(object):
1003    """Connection object."""
1004
1005    # expose the exceptions as attributes on the connection object
1006    Error = Error
1007    Warning = Warning
1008    InterfaceError = InterfaceError
1009    DatabaseError = DatabaseError
1010    InternalError = InternalError
1011    OperationalError = OperationalError
1012    ProgrammingError = ProgrammingError
1013    IntegrityError = IntegrityError
1014    DataError = DataError
1015    NotSupportedError = NotSupportedError
1016
1017    def __init__(self, cnx):
1018        """Create a database connection object."""
1019        self._cnx = cnx  # connection
1020        self._tnx = False  # transaction state
1021        self.type_cache = TypeCache(cnx)
1022        self.cursor_type = Cursor
1023        try:
1024            self._cnx.source()
1025        except Exception:
1026            raise _op_error("Invalid connection")
1027
1028    def __enter__(self):
1029        """Enter the runtime context for the connection object.
1030
1031        The runtime context can be used for running transactions.
1032        """
1033        return self
1034
1035    def __exit__(self, et, ev, tb):
1036        """Exit the runtime context for the connection object.
1037
1038        This does not close the connection, but it ends a transaction.
1039        """
1040        if et is None and ev is None and tb is None:
1041            self.commit()
1042        else:
1043            self.rollback()
1044
1045    def close(self):
1046        """Close the connection object."""
1047        if self._cnx:
1048            if self._tnx:
1049                try:
1050                    self.rollback()
1051                except DatabaseError:
1052                    pass
1053            self._cnx.close()
1054            self._cnx = None
1055        else:
1056            raise _op_error("Connection has been closed")
1057
1058    def commit(self):
1059        """Commit any pending transaction to the database."""
1060        if self._cnx:
1061            if self._tnx:
1062                self._tnx = False
1063                try:
1064                    self._cnx.source().execute("COMMIT")
1065                except DatabaseError:
1066                    raise
1067                except Exception:
1068                    raise _op_error("Can't commit")
1069        else:
1070            raise _op_error("Connection has been closed")
1071
1072    def rollback(self):
1073        """Roll back to the start of any pending transaction."""
1074        if self._cnx:
1075            if self._tnx:
1076                self._tnx = False
1077                try:
1078                    self._cnx.source().execute("ROLLBACK")
1079                except DatabaseError:
1080                    raise
1081                except Exception:
1082                    raise _op_error("Can't rollback")
1083        else:
1084            raise _op_error("Connection has been closed")
1085
1086    def cursor(self):
1087        """Return a new cursor object using the connection."""
1088        if self._cnx:
1089            try:
1090                return self.cursor_type(self)
1091            except Exception:
1092                raise _op_error("Invalid connection")
1093        else:
1094            raise _op_error("Connection has been closed")
1095
1096    if shortcutmethods:  # otherwise do not implement and document this
1097
1098        def execute(self, operation, params=None):
1099            """Shortcut method to run an operation on an implicit cursor."""
1100            cursor = self.cursor()
1101            cursor.execute(operation, params)
1102            return cursor
1103
1104        def executemany(self, operation, param_seq):
1105            """Shortcut method to run an operation against a sequence."""
1106            cursor = self.cursor()
1107            cursor.executemany(operation, param_seq)
1108            return cursor
1109
1110
1111### Module Interface
1112
1113_connect_ = connect
1114
1115def connect(dsn=None,
1116        user=None, password=None,
1117        host=None, database=None):
1118    """Connect to a database."""
1119    # first get params from DSN
1120    dbport = -1
1121    dbhost = ""
1122    dbbase = ""
1123    dbuser = ""
1124    dbpasswd = ""
1125    dbopt = ""
1126    try:
1127        params = dsn.split(":")
1128        dbhost = params[0]
1129        dbbase = params[1]
1130        dbuser = params[2]
1131        dbpasswd = params[3]
1132        dbopt = params[4]
1133    except (AttributeError, IndexError, TypeError):
1134        pass
1135
1136    # override if necessary
1137    if user is not None:
1138        dbuser = user
1139    if password is not None:
1140        dbpasswd = password
1141    if database is not None:
1142        dbbase = database
1143    if host is not None:
1144        try:
1145            params = host.split(":")
1146            dbhost = params[0]
1147            dbport = int(params[1])
1148        except (AttributeError, IndexError, TypeError, ValueError):
1149            pass
1150
1151    # empty host is localhost
1152    if dbhost == "":
1153        dbhost = None
1154    if dbuser == "":
1155        dbuser = None
1156
1157    # open the connection
1158    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
1159    return Connection(cnx)
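# A minimal end-to-end sketch, assuming a local server with a database
# named 'test' and suitable credentials:
#
#     con = connect(database='test', host='localhost:5432',
#         user='postgres', password='secret')
#     cur = con.cursor()
#     cur.execute("select 17 + 25 as answer")
#     print(cur.fetchone().answer)  # rows are named tuples -> 42
#     con.commit()
#     con.close()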
1160
1161
1162### Types Handling
1163
1164class Type(frozenset):
1165    """Type class for a couple of PostgreSQL data types.
1166
1167    PostgreSQL is object-oriented: types are dynamic.
1168    We must thus use type names as internal type codes.
1169    """
1170
1171    def __new__(cls, values):
1172        if isinstance(values, basestring):
1173            values = values.split()
1174        return super(Type, cls).__new__(cls, values)
1175
1176    def __eq__(self, other):
1177        if isinstance(other, basestring):
1178            if other.startswith('_'):
1179                other = other[1:]
1180            return other in self
1181        else:
1182            return super(Type, self).__eq__(other)
1183
1184    def __ne__(self, other):
1185        if isinstance(other, basestring):
1186            if other.startswith('_'):
1187                other = other[1:]
1188            return other not in self
1189        else:
1190            return super(Type, self).__ne__(other)
1191
1192
1193class ArrayType:
1194    """Type class for PostgreSQL array types."""
1195
1196    def __eq__(self, other):
1197        if isinstance(other, basestring):
1198            return other.startswith('_')
1199        else:
1200            return isinstance(other, ArrayType)
1201
1202    def __ne__(self, other):
1203        if isinstance(other, basestring):
1204            return not other.startswith('_')
1205        else:
1206            return not isinstance(other, ArrayType)
1207
1208
1209class RecordType:
1210    """Type class for PostgreSQL record types."""
1211
1212    def __eq__(self, other):
1213        if isinstance(other, TypeCode):
1214            return other.type == 'c'
1215        elif isinstance(other, basestring):
1216            return other == 'record'
1217        else:
1218            return isinstance(other, RecordType)
1219
1220    def __ne__(self, other):
1221        if isinstance(other, TypeCode):
1222            return other.type != 'c'
1223        elif isinstance(other, basestring):
1224            return other != 'record'
1225        else:
1226            return not isinstance(other, RecordType)
1227
1228
1229# Mandatory type objects defined by DB-API 2 specs:
1230
1231STRING = Type('char bpchar name text varchar')
1232BINARY = Type('bytea')
1233NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
1234DATETIME = Type('date time timetz timestamp timestamptz interval'
1235    ' abstime reltime')  # these are very old
1236ROWID = Type('oid')
1237
1238
1239# Additional type objects (more specific):
1240
1241BOOL = Type('bool')
1242SMALLINT = Type('int2')
1243INTEGER = Type('int2 int4 int8 serial')
1244LONG = Type('int8')
1245FLOAT = Type('float4 float8')
1246NUMERIC = Type('numeric')
1247MONEY = Type('money')
1248DATE = Type('date')
1249TIME = Type('time timetz')
1250TIMESTAMP = Type('timestamp timestamptz')
1251INTERVAL = Type('interval')
1252JSON = Type('json jsonb')
1253
1254# Type object for arrays (also equate to their base types):
1255
1256ARRAY = ArrayType()
1257
1258# Type object for records (encompassing all composite types):
1259
1260RECORD = RecordType()
1261
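# A sketch of comparing these type objects against the type_code values
# in cursor.description (assuming a cursor cur holding a result set):
#
#     for column in cur.description:
#         if column.type_code == NUMBER:
#             print(column.name, 'holds numbers')
#         elif column.type_code == ARRAY:
#             print(column.name, 'holds arrays')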
1262
1263# Mandatory type helpers defined by DB-API 2 specs:
1264
1265def Date(year, month, day):
1266    """Construct an object holding a date value."""
1267    return date(year, month, day)
1268
1269
1270def Time(hour, minute=0, second=0, microsecond=0):
1271    """Construct an object holding a time value."""
1272    return time(hour, minute, second, microsecond)
1273
1274
1275def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
1276    """Construct an object holding a time stamp value."""
1277    return datetime(year, month, day, hour, minute, second, microsecond)
1278
1279
1280def DateFromTicks(ticks):
1281    """Construct an object holding a date value from the given ticks value."""
1282    return Date(*localtime(ticks)[:3])
1283
1284
1285def TimeFromTicks(ticks):
1286    """Construct an object holding a time value from the given ticks value."""
1287    return Time(*localtime(ticks)[3:6])
1288
1289
1290def TimestampFromTicks(ticks):
1291    """Construct an object holding a time stamp from the given ticks value."""
1292    return Timestamp(*localtime(ticks)[:6])
1293
1294
1295class Binary(bytes):
1296    """Construct an object capable of holding a binary (long) string value."""
1297
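# A sketch of passing binary data, assuming a cursor cur and a table
# "blob" with a bytea column "data":
#
#     cur.execute("insert into blob (data) values (%s)",
#         [Binary(b'\x00\x01\x02')])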
1298
1299# Additional type helpers for PyGreSQL:
1300
1301class Bytea(bytes):
1302    """Construct an object capable of holding a bytea value."""
1303
1304
1305class Json:
1306    """Construct a wrapper for holding an object serializable to JSON."""
1307
1308    def __init__(self, obj, encode=None):
1309        self.obj = obj
1310        self.encode = encode or jsonencode
1311
1312    def __str__(self):
1313        obj = self.obj
1314        if isinstance(obj, basestring):
1315            return obj
1316        return self.encode(obj)
1317
1318    __pg_repr__ = __str__
1319
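# A sketch of passing a JSON parameter, assuming a cursor cur and a table
# "doc" with a json or jsonb column "data":
#
#     cur.execute("insert into doc (data) values (%s)",
#         [Json({'answer': 42, 'tags': ['a', 'b']})])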
1320
1321class Literal:
1322    """Construct a wrapper for holding a literal SQL string."""
1323
1324    def __init__(self, sql):
1325        self.sql = sql
1326
1327    def __str__(self):
1328        return self.sql
1329
1330    __pg_repr__ = __str__
1331
1332
1333# If run as script, print some information:
1334
1335if __name__ == '__main__':
1336    print('PyGreSQL version', version)
1337    print('')
1338    print(__doc__)