source: trunk/pgdb.py @ 796

Last change on this file since 796 was 796, checked in by cito, 3 years ago

Enrich type_code strings in the DB API 2

The type codes now carry e.g. the information whether a type is a record.
This makes it possible to provide a RECORD type object that compares equal to
all kinds of records, similar to the already existing ARRAY type object.

  • Property svn:keywords set to Id
File size: 38.2 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 796 2016-01-28 21:25:36Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is the same as for the
    # % operator, and the parameter values are quoted automatically.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
    cursor.fetchone() # fetch one row, (value, value, ...)

    cursor.fetchall() # fetch all rows, [(value, value, ...), ...]

    cursor.fetchmany([size])
    # returns size or cursor.arraysize number of rows,
    # [(value, value, ...), ...] from result set.
    # Default cursor.arraysize is 1.
    # By default, rows are returned as named tuples.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
    # Note that display_size and null_ok are not implemented,
    # and precision and scale are only set for numeric columns.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from json import loads as jsondecode, dumps as jsonencode
76
# Compatibility names for code that must run on Python 2 and Python 3.

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# The abstract base classes live in collections.abc since Python 3.3,
# and the old aliases in collections were removed in Python 3.10.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args, **kwargs):
            """Placeholder raising an error when ordered dicts are needed."""
            raise NotSupportedError('OrderedDict is not supported')
101
102
### Module Constants

# the DB API level this module complies with
apilevel = '2.0'

# threads may share the module, but must not share connections
threadsafety = 1

# parameters use extended Python format codes such as %(name)s
paramstyle = 'pyformat'

# Shortcut methods are not part of DB API 2 and the DB SIG does not
# recommend them, but they can be handy nevertheless.
shortcutmethods = 1
117
118
119### Internal Types Handling
120
def decimal_type(decimal_type=None):
    """Get or set global type to be used for decimal values.

    Without an argument, the currently active decimal type is returned.
    When a type (or other callable) is passed, it is registered as the
    cast function for the PostgreSQL 'numeric' type and also rebinds the
    module-level name Decimal, which other code in this module uses.
    The newly active type is then returned.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    # register the cast first, then rebind the global name
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    return Decimal
127
128
129def _cast_bool(value):
130    return value[:1] in ('t', 'T')
131
132
133def _cast_money(value):
134    return Decimal(''.join(filter(
135        lambda v: v in '0123456789.-', value)))
136
137
# Map PostgreSQL type names to the functions that cast their text output
# to Python values.  A value of str means that no cast is necessary.
_cast = dict(
    # character types
    char=str, bpchar=str, name=str, text=str, varchar=str,
    # boolean and binary data
    bool=_cast_bool, bytea=unescape_bytea,
    # integer types
    int2=int, int4=int, serial=int, int8=long,
    # JSON types
    json=jsondecode, jsonb=jsondecode,
    # object identifiers
    oid=long, oid8=long,
    # floating point types
    float4=float, float8=float,
    # exact decimal types
    numeric=Decimal, money=_cast_money,
    # anonymous records
    record=cast_record,
)
147
148
def _db_error(msg, cls=DatabaseError):
    """Return DatabaseError with empty sqlstate attribute.

    The exception class can be chosen with cls; the created exception
    always gets its sqlstate attribute set to None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
154
155
def _op_error(msg):
    """Return OperationalError with empty sqlstate attribute."""
    return _db_error(msg, cls=OperationalError)
159
160
class TypeCode(str):
    """Class representing the type_code used by the DB-API 2.0.

    TypeCode objects are strings equal to the PostgreSQL type name,
    but carry some additional information from the pg_type catalog
    as attributes (oid, len, type, category, delim and relid).
    """

    @classmethod
    def create(cls, oid, name, len, type, category, delim, relid):
        """Create a type code for a PostgreSQL data type.

        The returned object compares equal to name and has the remaining
        arguments attached as attributes of the same names.
        """
        type_code = cls(name)
        type_code.oid, type_code.len, type_code.type = oid, len, type
        type_code.category, type_code.delim = category, delim
        type_code.relid = relid
        return type_code
179
# name and type OID of one column of a composite type
ColumnInfo = namedtuple('ColumnInfo', 'name type')
181
182
class TypeCache(dict):
    """Cache for database types.

    This cache maps type OIDs and names to TypeCode strings containing
    important information on the associated database type.  Types that
    are not cached yet are looked up in the pg_type catalog on demand.
    The columns of composite types and the named tuple classes used for
    their values are cached as well, so the catalog is not queried again
    for every single value.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()
        # columns of composite types, keyed by the relid of the type
        self._columns = {}
        # named tuple classes for composite types,
        # keyed by (type name, tuple of field names)
        self._records = {}

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The key can be a type OID (int) or a type name.  Names without
        a dot or double quote are quoted before being cast to regtype.
        Raises KeyError if the type cannot be found.
        """
        if isinstance(key, int):
            oid = key
        else:
            if '.' not in key and '"' not in key:
                key = '"%s"' % key
            oid = "'%s'::regtype" % self._escape_string(key)
        try:
            self._src.execute("SELECT oid, typname,"
                " typlen, typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s" % oid)
        except ProgrammingError:
            res = None
        else:
            res = self._src.fetch(1)
        if not res:
            raise KeyError('Type %s could not be found' % key)
        res = res[0]
        type_code = TypeCode.create(int(res[0]), res[1],
            int(res[2]), res[3], res[4], res[5], int(res[6]))
        # cache the type under its OID and under its name
        self[type_code.oid] = self[str(type_code)] = type_code
        return type_code

    def get(self, key, default=None):
        """Get the type even if it is not cached."""
        try:
            return self[key]
        except KeyError:
            return default

    def columns(self, key):
        """Get the names and types of the columns of composite types.

        Returns a list of ColumnInfo tuples, or None if the type is not
        known or not composite.  The columns are fetched from pg_attribute
        only once per composite type and cached afterwards.
        """
        try:
            typ = self[key]
        except KeyError:
            return None  # this type is not known
        if typ.type != 'c' or not typ.relid:
            return None  # this type is not composite
        cols = self._columns.get(typ.relid)
        if cols is None:
            self._src.execute("SELECT attname, atttypid"
                " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
                " AND NOT attisdropped ORDER BY attnum" % typ.relid)
            cols = [ColumnInfo(name, int(oid))
                for name, oid in self._src.fetch(-1)]
            self._columns[typ.relid] = cols
        # return a copy so that callers cannot modify the cached list
        return list(cols)

    @staticmethod
    def _record_class(typename, fields):
        """Create a named tuple class for records with the given fields.

        PostgreSQL type and attribute names need not be valid Python
        identifiers.  Invalid field names are replaced with positional
        names, and an invalid type name is replaced with 'Record'.
        """
        for name in (typename, 'Record'):
            try:
                try:
                    return namedtuple(name, fields, rename=True)
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    return namedtuple(name, fields)
            except ValueError:  # invalid type name or field names
                continue
        # no rename support and invalid field names: use positional names
        return namedtuple('Record',
            ['column_%d' % n for n in range(len(fields))])

    def typecast(self, typ, value):
        """Cast value according to database type.

        NULL values (None) are returned unchanged.  Types with a registered
        cast function use it, array types (names starting with '_') are cast
        element-wise, and composite types are returned as named tuples.
        Values of other types are returned unchanged.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        cast = _cast.get(typ)
        if cast is str:
            return value  # no typecast necessary
        if cast is not None:
            return cast(value)
        if typ.startswith('_'):
            # cast as an array type
            return cast_array(value, _cast.get(typ[1:]))
        # check whether this is a composite type
        cols = self.columns(typ)
        if not cols:
            return value  # no typecast available or necessary
        getcast = self.getcast
        value = cast_record(value, [getcast(col.type) for col in cols])
        fields = [col.name for col in cols]
        key = (str(typ), tuple(fields))
        record = self._records.get(key)
        if record is None:
            # creating named tuple classes is expensive, so reuse them
            record = self._records[key] = self._record_class(typ, fields)
        return record(*value)

    def getcast(self, key):
        """Get a cast function for the given database type.

        The key can be a type OID (int), in which case None is returned
        for unknown types, or a type name.
        """
        if isinstance(key, int):
            try:
                typ = self[key]
            except KeyError:
                return None
        else:
            typ = key
        typecast = self.typecast
        return lambda value: typecast(typ, value)
278
279
280class _quotedict(dict):
281    """Dictionary with auto quoting of its items.
282
283    The quote attribute must be set to the desired quote function.
284    """
285
286    def __getitem__(self, key):
287        return self.quote(super(_quotedict, self).__getitem__(key))
288
289
290### Cursor Object
291
class Cursor(object):
    """Cursor object."""

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self.type_cache = dbcnx.type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns
        self._description = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a custom row factory has been defined statically
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object."""
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Returns an SQL literal string, or the value itself for numbers
        that can be inserted into the query as they are.  Raises
        InterfaceError for values of unsupported types that do not
        provide a __pg_repr__() method.
        """
        if val is None:
            return 'NULL'
        if isinstance(val, (datetime, date, time, timedelta, Json)):
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            return "'%s'" % val
        if isinstance(val, float):
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            if isnan(val):
                return "'NaN'"
            return val
        if isinstance(val, (int, long, Decimal)):
            return val
        if isinstance(val, list):
            # Quote value as an ARRAY constructor. This is better than using
            # an array literal because it carries the information that this is
            # an array and not a string.  One issue with this syntax is that
            # you need to add an explicit type cast when passing empty arrays.
            # The ARRAY keyword is actually only necessary at the top level.
            q = self._quote
            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in val)
        if isinstance(val, tuple):
            # Quote as a ROW constructor.  This is better than using a record
            # literal because it carries the information that this is a record
            # and not a string.  We don't use the keyword ROW in order to make
            # this usable with the IN syntax as well.  It is only necessary
            # when the record has a single column which is not really useful.
            q = self._quote
            return '(%s)' % ','.join(str(q(v)) for v in val)
        # Look up the method first, so that an AttributeError raised
        # inside __pg_repr__() itself is not misreported as missing support.
        try:
            pg_repr = val.__pg_repr__
        except AttributeError:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))
        return pg_repr()

    def _quoteparams(self, string, parameters):
        """Quote parameters.

        This function works for both mappings and sequences.
        """
        if isinstance(parameters, dict):
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def _make_description(self, info):
        """Make the description tuple for the given field info.

        The info is a tuple as returned by listinfo(); its last four
        items are the column name, type OID, size and type modifier.
        """
        name, typ, size, mod = info[1:]
        type_code = self.type_cache[typ]
        if mod > 0:
            # the type modifier includes a 4 byte header
            mod -= 4
        if type_code == 'numeric':
            # NOTE(review): for numeric columns without a type modifier
            # (mod == -1) this yields precision -1 and scale 65535.
            precision, scale = mod >> 16, mod & 0xffff
            size = precision
        else:
            if not size:
                size = type_code.len
            if size == -1:
                size = mod
            precision = scale = None
        return CursorDescription(name, type_code,
            None, size, precision, scale, None)

    @property
    def description(self):
        """Read-only attribute describing the result columns."""
        descr = self._description
        if self._description is True:
            make = self._make_description
            descr = [make(info) for info in self._src.listinfo()]
            self._description = descr
        return descr

    @property
    def colnames(self):
        """Unofficial convenience method for getting the column names.

        Returns None if there is no result description.
        """
        description = self.description
        if description is None:
            return None
        return [d[0] for d in description]

    @property
    def coltypes(self):
        """Unofficial convenience method for getting the column types.

        Returns None if there is no result description.
        """
        description = self.description
        if description is None:
            return None
        return [d[1] for d in description]

    def close(self):
        """Close the cursor object."""
        self._src.close()
        self._description = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command)."""
        # The parameters may also be specified as list of tuples to e.g.
        # insert multiple rows in a single operation, but this kind of
        # usage is deprecated.  We make several plausibility checks because
        # tuples can also be passed with the meaning of ROW constructors.
        if (parameters and isinstance(parameters, list)
                and len(parameters) > 1
                and all(isinstance(p, tuple) for p in parameters)
                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
            return self.executemany(operation, parameters)
        else:
            # not a list of tuples
            return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        A transaction is started first if none is active on the connection.
        Returns the cursor itself, or None if seq_of_parameters is empty.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self._description = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                if parameters:
                    sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                if rows:  # the statement reported a number of rows
                    rowcount += rows
                else:
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result raw count and description
        if self._src.resulttype == RESULT_DQL:
            self._description = True  # fetch on demand
            self.rowcount = self._src.ntuples
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set."""
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.

        The values are cast according to the column types, and each row
        is passed through the row factory.  If no row factory is set
        (e.g. for results without columns), rows are returned as lists.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        # determine the column types only once, not once per row
        typecast = self.type_cache.typecast
        coltypes = self.coltypes
        rows = [[typecast(typ, value) for typ, value in zip(coltypes, row)]
            for row in result]
        row_factory = self.row_factory
        if row_factory is None:
            return rows
        return [row_factory(row) for row in rows]

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = parameters and len(parameters) or 0
        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
        self.execute(query, parameters)
        return parameters

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.

        The size option sets the size of the buffer used when reading data
        from file-like objects.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.

        In the non binary formats, string input that does not end with a
        newline gets one appended; binary input is passed on unchanged.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            if size:
                raise ValueError("size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("The input must be %s" % type_name)
                if not binary_format:
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    else:
                        if not stream.endswith(b'\n'):
                            stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "Input stream must consist of %s" % type_name)
                        # binary data must be passed on unchanged, only
                        # textual rows need to be terminated by a newline
                        if not binary_format:
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            else:
                                if not chunk.endswith(b'\n'):
                                    chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("Need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, int):
                raise TypeError("The size option must be an integer")
            if size > 0:

                def chunks():
                    while True:
                        buffer = read(size)
                        yield buffer
                        if not buffer or len(buffer) < size:
                            break

            else:

                def chunks():
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("Must specify a table, not a query")
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy from")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            if binary_format:
                decode = False
            else:
                # decode by default only where str is a unicode type
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("decode option must be a boolean")
            if decode and binary_format:
                raise ValueError("decode is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            # getdata() returns rows until it finally returns an int
            self.rowcount = 0
            while True:
                row = getdata(decode)
                if isinstance(row, int):
                    if self.rowcount != row:
                        self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no output stream, return the generator
            return copy()

        # write the rows to the file-like output stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Note that since Python 2.6 the iterator protocol uses __next__()
    # instead of next(), we keep it only for backward compatibility of pgdb.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.
        Returns None if the result has no columns.
        """
        colnames = self.colnames
        if colnames:
            try:
                try:
                    return namedtuple('Row', colnames, rename=True)._make
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    colnames = [v if v.isalnum() else 'column_%d' % n
                             for n, v in enumerate(colnames)]
                    return namedtuple('Row', colnames)._make
            except ValueError:  # there is still a problem with the field names
                colnames = ['column_%d' % n for n in range(len(colnames))]
                return namedtuple('Row', colnames)._make
842
843
# Named form of the 7-item column description defined by the DB-API 2
# (name, type_code, display_size, internal_size, precision, scale, null_ok).
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
847
848
849### Connection Objects
850
class Connection(object):
    """Connection object."""

    # The DB-API exception classes are also reachable as attributes of
    # every connection object.
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        The given low-level connection is checked by requesting a source
        object from it; any failure is reported as "invalid connection".
        """
        self._cnx = cnx  # underlying connection, set to None by close()
        self._tnx = False  # True while a transaction is pending
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # class instantiated by cursor()
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is committed on normal exit and rolled back on an exception.
        """
        if any(arg is not None for arg in (et, ev, tb)):
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back first (database errors during
        that rollback are ignored).  Closing twice raises an error.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass
        self._cnx.close()
        self._cnx = None

    def _end_transaction(self, command, failure):
        """Send the given command if a transaction is pending.

        Database errors propagate unchanged; any other exception is
        reported with the given failure message.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if not self._tnx:
            return
        # the transaction counts as ended even if the command fails
        self._tnx = False
        try:
            self._cnx.source().execute(command)
        except DatabaseError:
            raise
        except Exception:
            raise _op_error(failure)

    def commit(self):
        """Commit any pending transaction to the database."""
        self._end_transaction("COMMIT", "can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        self._end_transaction("ROLLBACK", "can't rollback")

    def cursor(self):
        """Return a new cursor object using the connection.

        The cursor is created as an instance of self.cursor_type.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cur = self.cursor()
            cur.execute(operation, params)
            return cur

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cur = self.cursor()
            cur.executemany(operation, param_seq)
            return cur
958
959
960### Module Interface
961
_connect_ = connect  # keep the previous connect, used below to open cnx

def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database.

    The optional dsn has the form 'host:database:user:password:opt', where
    trailing parts may be omitted.  The user, password and database keyword
    arguments override the corresponding dsn parts.  The host keyword
    argument overrides the dsn host and may carry a port as 'host:port'.
    Empty host and user values are passed on as None.
    """
    # first get params from DSN: host, database, user, password, opt
    dsn_parts = ["", "", "", "", ""]
    try:
        for index, value in enumerate(dsn.split(":")[:5]):
            dsn_parts[index] = value
    except (AttributeError, IndexError, TypeError):
        pass
    dbhost, dbbase, dbuser, dbpasswd, dbopt = dsn_parts
    dbport = -1

    # explicit keyword arguments take precedence
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # keep whatever could be parsed (port stays the default)

    # empty host is localhost; empty values are passed on as None
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    # open the connection
    return Connection(_connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd))
1009
1010
1011### Types Handling
1012
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    An instance compares equal to the name of any type it contains and
    also to the name of an array of such a type (array type names are
    the base names prefixed with an underscore).
    """

    def __new__(cls, values):
        # a single string is a whitespace-separated list of type names
        names = values.split() if isinstance(values, basestring) else values
        return super(Type, cls).__new__(cls, names)

    @staticmethod
    def _base_name(name):
        """Strip the leading underscore marking an array type name."""
        return name[1:] if name.startswith('_') else name

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return super(Type, self).__eq__(other)
        return self._base_name(other) in self

    def __ne__(self, other):
        if not isinstance(other, basestring):
            return super(Type, self).__ne__(other)
        return self._base_name(other) not in self
1040
1041
class ArrayType:
    """Type class for PostgreSQL array types.

    Array type names start with an underscore, so every type code string
    with that prefix compares equal to an instance of this class.
    """

    def _matches(self, other):
        """Tell whether other denotes an array type."""
        if isinstance(other, basestring):
            return other.startswith('_')
        return isinstance(other, ArrayType)

    def __eq__(self, other):
        return self._matches(other)

    def __ne__(self, other):
        return not self._matches(other)
1056
1057
class RecordType:
    """Type class for PostgreSQL record types.

    The RECORD type object (an instance of this class) is meant to compare
    equal to all kinds of records: type codes of composite types (type
    category 'c') as well as the generic 'record' type that PostgreSQL
    uses for anonymous row values.
    """

    def __eq__(self, other):
        if isinstance(other, TypeCode):
            # The generic 'record' type is a pseudo-type, not a composite
            # one, so it must also be recognized by its name.
            return other.type == 'c' or other == 'record'
        elif isinstance(other, basestring):
            return other == 'record'
        else:
            return isinstance(other, RecordType)

    def __ne__(self, other):
        return not self.__eq__(other)
1076
1077
# Mandatory type objects defined by DB-API 2 specs:

STRING = Type(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = Type(['bytea'])
NUMBER = Type(['int2', 'int4', 'serial', 'int8',
               'float4', 'float8', 'numeric', 'money'])
DATETIME = Type(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
                 'interval',
                 'abstime', 'reltime'])  # the last two are very old
ROWID = Type(['oid'])


# Additional type objects (more specific):

BOOL = Type(['bool'])
SMALLINT = Type(['int2'])
INTEGER = Type(['int2', 'int4', 'int8', 'serial'])
LONG = Type(['int8'])
FLOAT = Type(['float4', 'float8'])
NUMERIC = Type(['numeric'])
MONEY = Type(['money'])
DATE = Type(['date'])
TIME = Type(['time', 'timetz'])
TIMESTAMP = Type(['timestamp', 'timestamptz'])
INTERVAL = Type(['interval'])
JSON = Type(['json', 'jsonb'])

# Type object matching every array type (note that the Type objects
# above also compare equal to arrays of the types they contain):

ARRAY = ArrayType()

# Type object for records (encompassing all composite types):

RECORD = RecordType()
1110
1111
1112# Mandatory type helpers defined by DB-API 2 specs:
1113
def Date(year, month, day):
    """Construct an object holding a date value.

    The result is a datetime.date instance.
    """
    return date(year=year, month=month, day=day)
1117
1118
def Time(hour, minute=0, second=0, microsecond=0):
    """Construct an object holding a time value.

    The result is a naive datetime.time instance.
    """
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond)
1122
1123
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Construct an object holding a time stamp value.

    The result is a naive datetime.datetime instance.
    """
    return datetime(year=year, month=month, day=day, hour=hour,
                    minute=minute, second=second, microsecond=microsecond)
1127
1128
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks (seconds since the epoch) are interpreted in local time.
    """
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
1132
1133
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    The ticks (seconds since the epoch) are interpreted in local time;
    only hour, minute and whole seconds are kept.
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
1137
1138
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    The ticks (seconds since the epoch) are interpreted in local time;
    fractions of a second are dropped.
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
1142
1143
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is a plain bytes subclass without any additional behavior.
    """
1146
1147
1148# Additional type helpers for PyGreSQL:
1149
class Json:
    """Construct a wrapper for holding an object serializable to JSON.

    str() and __pg_repr__() both return the JSON text: strings are
    assumed to contain JSON already and are returned unchanged, all other
    objects are serialized with the encode function (jsonencode by
    default).
    """

    def __init__(self, obj, encode=None):
        self.obj = obj
        self.encode = encode if encode else jsonencode

    def __str__(self):
        value = self.obj
        return value if isinstance(value, basestring) else self.encode(value)

    __pg_repr__ = __str__
1164
1165
1166# If run as script, print some information:
1167
if __name__ == '__main__':
    # When run as a script, print the PyGreSQL version, a blank line
    # and the module documentation.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.