source: trunk/pgdb.py @ 788

Last change on this file since 788 was 788, checked in by cito, 4 years ago

Add method columns() to type cache

The method returns the columns of composite types.
This makes the type cache even more useful.

  • Property svn:keywords set to Id
File size: 36.1 KB
RevLine 
[521]1#! /usr/bin/python
[222]2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
[404]7# $Id: pgdb.py 788 2016-01-26 21:13:55Z cito $
[222]8#
[184]9
[205]10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
[27]11
[222]12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
[27]14
[222]15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
[27]18
[222]19Basic usage:
[27]20
[346]21    pgdb.connect(connect_string) # open a connection
[680]22    # connect_string = 'host:database:user:password:opt'
[346]23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
[522]26    connection = pgdb.connect(host='localhost:5432')
[27]27
[522]28    cursor = connection.cursor() # open a cursor
[27]29
[346]30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary) if they are
32    # passed. The binding syntax is the same as the % operator
33    # for dictionaries; parameter values are quoted automatically.
[27]34
[346]35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
[27]38
[346]39    cursor.fetchone() # fetch one row, [value, value, ...]
[27]40
[346]41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
[27]42
[346]43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
[27]47
[346]48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
[681]51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
[27]53
[346]54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
[27]56
[346]57    connection.commit() # commit transaction
[27]58
[346]59    connection.rollback() # or rollback transaction
[27]60
[346]61    cursor.close() # close the cursor
[27]62
[346]63    connection.close() # close the connection
[27]64"""
65
[627]66from __future__ import print_function
67
[328]68from _pg import *
[553]69
[423]70from datetime import date, time, datetime, timedelta
71from time import localtime
[553]72from decimal import Decimal
73from math import isnan, isinf
[681]74from collections import namedtuple
[781]75from re import compile as regex
[774]76from json import loads as jsondecode, dumps as jsonencode
[49]77
[599]78try:
79    long
80except NameError:  # Python >= 3.0
81    long = int
82
83try:
[629]84    unicode
85except NameError:  # Python >= 3.0
86    unicode = str
87
88try:
[599]89    basestring
90except NameError:  # Python >= 3.0
91    basestring = (str, bytes)
92
[693]93from collections import Iterable
[682]94try:
95    from collections import OrderedDict
96except ImportError:  # Python 2.6 or 3.0
97    try:
98        from ordereddict import OrderedDict
99    except Exception:
100        def OrderedDict(*args):
101            raise NotSupportedError('OrderedDict is not supported')
102
[27]103
[340]104### Module Constants
105
[463]106# compliant with DB API 2.0
[27]107apilevel = '2.0'
108
109# module may be shared, but not connections
110threadsafety = 1
111
112# this module uses extended Python format codes
113paramstyle = 'pyformat'
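
# Example of the 'pyformat' parameter style (illustrative only; the table
# and connection names are assumptions, not part of this module):
#
#     cur = con.cursor()
#     cur.execute("select * from fruits where name = %(name)s",
#         {'name': 'apple'})
#
# Positional parameters can also be passed as a sequence with %s markers.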
114
[772]115# shortcut methods have been excluded from DB API 2 and
116# are not recommended by the DB SIG, but they can be handy
117shortcutmethods = 1
[205]118
[463]119
[340]120### Internal Types Handling
121
[373]122def decimal_type(decimal_type=None):
123    """Get or set global type to be used for decimal values."""
124    global Decimal
125    if decimal_type is not None:
[423]126        _cast['numeric'] = Decimal = decimal_type
[373]127    return Decimal
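
# For illustration: numeric columns can be mapped to float instead of
# Decimal (this is a global, module-wide setting):
#
#     decimal_type(float)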
128
129
[340]130def _cast_bool(value):
[423]131    return value[:1] in ('t', 'T')
[340]132
133
134def _cast_money(value):
[346]135    return Decimal(''.join(filter(
136        lambda v: v in '0123456789.-', value)))
[340]137
138
[774]139_cast = {'bool': _cast_bool, 'bytea': unescape_bytea,
[346]140    'int2': int, 'int4': int, 'serial': int,
[774]141    'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
142    'oid': long, 'oid8': long,
143    'float4': float, 'float8': float,
[346]144    'numeric': Decimal, 'money': _cast_money}
[340]145
146
[434]147def _db_error(msg, cls=DatabaseError):
[740]148    """Return DatabaseError with empty sqlstate attribute."""
[434]149    error = cls(msg)
150    error.sqlstate = None
151    return error
152
153
154def _op_error(msg):
[740]155    """Return OperationalError."""
[434]156    return _db_error(msg, OperationalError)
157
158
[784]159TypeInfo = namedtuple('TypeInfo',
160    ['oid', 'name', 'len', 'type', 'category', 'delim', 'relid'])
161
[788]162ColumnInfo = namedtuple('ColumnInfo', ['name', 'type'])
[784]163
[788]164
[679]165class TypeCache(dict):
[784]166    """Cache for database types.
[27]167
[787]168    This cache maps type OIDs and names to TypeInfo tuples containing
169    important information on the associated database type.
[784]170    """
171
[346]172    def __init__(self, cnx):
173        """Initialize type cache for connection."""
[679]174        super(TypeCache, self).__init__()
[784]175        self._escape_string = cnx.escape_string
[346]176        self._src = cnx.source()
[27]177
[784]178    def __missing__(self, key):
[788]179        """Get the type info from the database if it is not cached."""
[784]180        if isinstance(key, int):
[788]181            oid = key
[784]182        else:
[788]183            if '.' not in key and '"' not in key:
184                key = '"%s"' % key
185            oid = "'%s'::regtype" % self._escape_string(key)
186        self._src.execute("SELECT oid, typname,"
187            " typlen, typtype, typcategory, typdelim, typrelid"
188            " FROM pg_type WHERE oid=%s" % oid)
[787]189        res = self._src.fetch(1)
190        if not res:
191            raise KeyError('Type %s could not be found' % key)
192        res = list(res[0])
[784]193        res[0] = int(res[0])
194        res[2] = int(res[2])
195        res[6] = int(res[6])
196        res = TypeInfo(*res)
197        self[res.oid] = self[res.name] = res
198        return res
199
[788]200    def columns(self, key):
201        """Get the names and types of the columns of composite types."""
202        typ = self[key]
203        if typ.type != 'c' or not typ.relid:
204            return []  # this type is not composite
205        self._src.execute("SELECT attname, atttypid"
206            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
207            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
208        return [ColumnInfo(name, int(oid))
209            for name, oid in self._src.fetch(-1)]
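
    # Illustrative use of the type cache (the composite type name is an
    # assumption, not something defined by this module):
    #
    #     info = con.type_cache['inventory_item']   # TypeInfo named tuple
    #     cols = con.type_cache.columns('inventory_item')
    #     # -> [ColumnInfo(name='name', type=<oid of column type>), ...]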
210
[600]211    @staticmethod
[346]212    def typecast(typ, value):
[781]213        """Cast value according to database type."""
[346]214        if value is None:
215            # for NULL values, no typecast is necessary
216            return None
217        cast = _cast.get(typ)
218        if cast is None:
[781]219            if typ.startswith('_'):
220                # cast as an array type
221                cast = _cast.get(typ[1:])
222                return cast_array(value, cast)
[346]223            # no typecast available or necessary
224            return value
225        else:
226            return cast(value)
[27]227
228
[781]229_re_array_escape = regex(r'(["\\])')
230_re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
231
[784]232
[365]233class _quotedict(dict):
234    """Dictionary with auto quoting of its items.
[342]235
[365]236    The quote attribute must be set to the desired quote function.
237    """
[340]238
[365]239    def __getitem__(self, key):
240        return self.quote(super(_quotedict, self).__getitem__(key))
[340]241
242
243### Cursor Object
244
[679]245class Cursor(object):
[463]246    """Cursor object."""
[27]247
[346]248    def __init__(self, dbcnx):
249        """Create a cursor object for the database connection."""
[372]250        self.connection = self._dbcnx = dbcnx
[346]251        self._cnx = dbcnx._cnx
[787]252        self.type_cache = dbcnx.type_cache
[346]253        self._src = self._cnx.source()
[683]254        # the official attribute for describing the result columns
[786]255        self._description = None
[683]256        if self.row_factory is Cursor.row_factory:
257            # the row factory needs to be determined dynamically
258            self.row_factory = None
259        else:
260            self.build_row_factory = None
[346]261        self.rowcount = -1
262        self.arraysize = 1
263        self.lastrowid = None
[27]264
[374]265    def __iter__(self):
[740]266        """Make cursor compatible to the iteration protocol."""
[374]267        return self
268
[438]269    def __enter__(self):
270        """Enter the runtime context for the cursor object."""
271        return self
272
273    def __exit__(self, et, ev, tb):
274        """Exit the runtime context for the cursor object."""
275        self.close()
276
[365]277    def _quote(self, val):
278        """Quote value depending on its type."""
[781]279        if val is None:
280            return 'NULL'
[774]281        if isinstance(val, (datetime, date, time, timedelta, Json)):
[365]282            val = str(val)
[629]283        if isinstance(val, basestring):
[401]284            if isinstance(val, Binary):
[428]285                val = self._cnx.escape_bytea(val)
[629]286                if bytes is not str:  # Python >= 3.0
287                    val = val.decode('ascii')
[401]288            else:
[428]289                val = self._cnx.escape_string(val)
[781]290            return "'%s'" % val
291        if isinstance(val, float):
[423]292            if isinf(val):
[622]293                return "'-Infinity'" if val < 0 else "'Infinity'"
[781]294            if isnan(val):
[423]295                return "'NaN'"
[781]296            return val
297        if isinstance(val, (int, long, Decimal)):
298            return val
299        if isinstance(val, list):
300            return "'%s'" % self._quote_array(val)
301        if isinstance(val, tuple):
[772]302            q = self._quote
[781]303            return 'ROW(%s)' % ','.join(str(q(v)) for v in val)
304        try:
305            return val.__pg_repr__()
306        except AttributeError:
[377]307            raise InterfaceError(
308                'do not know how to handle type %s' % type(val))
[365]309
[781]310    def _quote_array(self, val):
311        """Quote value as a literal constant for an array."""
312        # We could also cast to an array constructor here, but that is more
313        # verbose and you need to know the base type to build empty arrays.
314        if isinstance(val, list):
315            return '{%s}' % ','.join(self._quote_array(v) for v in val)
316        if val is None:
317            return 'null'
318        if isinstance(val, (int, long, float)):
319            return str(val)
320        if isinstance(val, bool):
321            return 't' if val else 'f'
322        if isinstance(val, basestring):
323            if not val:
324                return '""'
325            if _re_array_quote.search(val):
326                return '"%s"' % _re_array_escape.sub(r'\\\1', val)
327            return val
328        try:
329            return val.__pg_repr__()
330        except AttributeError:
331            raise InterfaceError(
332                'do not know how to handle type %s' % type(val))
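
    # For illustration, parameter binding uses the quoting above to turn
    # Python values into SQL literals (the table name is an assumption):
    #
    #     cursor.execute("insert into t values (%s, %s, %s)",
    #         [[1, 2, 3], ('a', 1), None])
    #     # executes: insert into t values ('{1,2,3}', ROW('a',1), NULL)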
333
[692]334    def _quoteparams(self, string, parameters):
[365]335        """Quote parameters.
336
337        This function works for both mappings and sequences.
338        """
[692]339        if isinstance(parameters, dict):
340            parameters = _quotedict(parameters)
341            parameters.quote = self._quote
[365]342        else:
[692]343            parameters = tuple(map(self._quote, parameters))
344        return string % parameters
[365]345
[784]346    def _make_description(self, info):
347        """Make the description tuple for the given field info."""
348        name, typ, size, mod = info[1:]
[787]349        type_info = self.type_cache[typ]
[784]350        type_code = type_info.name
351        if mod > 0:
352            mod -= 4
353        if type_code == 'numeric':
354            precision, scale = mod >> 16, mod & 0xffff
355            size = precision
356        else:
357            if not size:
358                size = type_info.len
359            if size == -1:
360                size = mod
361            precision = scale = None
362        return CursorDescription(name, type_code,
363            None, size, precision, scale, None)
364
[786]365    @property
366    def description(self):
367        """Read-only attribute describing the result columns."""
368        descr = self._description
369        if self._description is True:
370            make = self._make_description
371            descr = [make(info) for info in self._src.listinfo()]
372            self._description = descr
373        return descr
374
375    @property
376    def colnames(self):
377        """Unofficial convenience method for getting the column names."""
378        return [d[0] for d in self.description]
379
380    @property
381    def coltypes(self):
382        """Unofficial convenience method for getting the column types."""
383        return [d[1] for d in self.description]
384
[346]385    def close(self):
386        """Close the cursor object."""
387        self._src.close()
[786]388        self._description = None
[346]389        self.rowcount = -1
390        self.lastrowid = None
[27]391
[692]392    def execute(self, operation, parameters=None):
[346]393        """Prepare and execute a database operation (query or command)."""
[774]394        # The parameters may also be specified as list of tuples to e.g.
395        # insert multiple rows in a single operation, but this kind of
396        # usage is deprecated.  We make several plausibility checks because
397        # tuples can also be passed with the meaning of ROW constructors.
398        if (parameters and isinstance(parameters, list)
399                and len(parameters) > 1
400                and all(isinstance(p, tuple) for p in parameters)
401                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
[692]402            return self.executemany(operation, parameters)
[346]403        else:
404            # not a list of tuples
[692]405            return self.executemany(operation, [parameters])
[27]406
[692]407    def executemany(self, operation, seq_of_parameters):
[346]408        """Prepare operation and execute it against a parameter sequence."""
[692]409        if not seq_of_parameters:
[346]410            # don't do anything without parameters
411            return
[786]412        self._description = None
[346]413        self.rowcount = -1
414        # first try to execute all queries
[683]415        rowcount = 0
[346]416        sql = "BEGIN"
417        try:
418            if not self._dbcnx._tnx:
419                try:
420                    self._cnx.source().execute(sql)
[434]421                except DatabaseError:
[772]422                    raise  # database provides error message
423                except Exception as err:
[434]424                    raise _op_error("can't start transaction")
[346]425                self._dbcnx._tnx = True
[692]426            for parameters in seq_of_parameters:
[774]427                sql = operation
[692]428                if parameters:
[774]429                    sql = self._quoteparams(sql, parameters)
[346]430                rows = self._src.execute(sql)
[436]431                if rows:  # true if not DML
[683]432                    rowcount += rows
[346]433                else:
434                    self.rowcount = -1
[434]435        except DatabaseError:
[772]436            raise  # database provides error message
[523]437        except Error as err:
[772]438            raise _db_error(
439                "error in '%s': '%s' " % (sql, err), InterfaceError)
[523]440        except Exception as err:
[434]441            raise _op_error("internal error in '%s': %s" % (sql, err))
[346]442        # then initialize result raw count and description
443        if self._src.resulttype == RESULT_DQL:
[786]444            self._description = True  # fetch on demand
[346]445            self.rowcount = self._src.ntuples
[398]446            self.lastrowid = None
[683]447            if self.build_row_factory:
448                self.row_factory = self.build_row_factory()
[346]449        else:
[683]450            self.rowcount = rowcount
[346]451            self.lastrowid = self._src.oidstatus()
[398]452        # return the cursor object, so you can write statements such as
453        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
454        return self
[27]455
[346]456    def fetchone(self):
457        """Fetch the next row of a query result set."""
458        res = self.fetchmany(1, False)
459        try:
460            return res[0]
461        except IndexError:
462            return None
[27]463
[346]464    def fetchall(self):
465        """Fetch all (remaining) rows of a query result."""
466        return self.fetchmany(-1, False)
[27]467
[346]468    def fetchmany(self, size=None, keep=False):
469        """Fetch the next set of rows of a query result.
[343]470
[346]471        The number of rows to fetch per call is specified by the
472        size parameter. If it is not given, the cursor's arraysize
473        determines the number of rows to be fetched. If you set
474        the keep parameter to true, this is kept as new arraysize.
475        """
476        if size is None:
477            size = self.arraysize
478        if keep:
479            self.arraysize = size
480        try:
481            result = self._src.fetch(size)
[434]482        except DatabaseError:
483            raise
[523]484        except Error as err:
[434]485            raise _db_error(str(err))
[787]486        typecast = self.type_cache.typecast
[682]487        return [self.row_factory([typecast(typ, value)
488            for typ, value in zip(self.coltypes, row)]) for row in result]
[27]489
[692]490    def callproc(self, procname, parameters=None):
491        """Call a stored database procedure with the given name.
492
493        The sequence of parameters must contain one entry for each input
494        argument that the procedure expects. The result of the call is the
495        same as this input sequence; replacement of output and input/output
496        parameters in the return value is currently not supported.
497
498        The procedure may also provide a result set as output. These can be
499        requested through the standard fetch methods of the cursor.
500        """
501        n = parameters and len(parameters) or 0
502        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
503        self.execute(query, parameters)
504        return parameters
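
    # Illustrative call of a built-in function through callproc():
    #
    #     cur.callproc('version')    # runs: select * from "version"()
    #     print(cur.fetchone()[0])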
505
[693]506    def copy_from(self, stream, table,
507            format=None, sep=None, null=None, size=None, columns=None):
508        """Copy data from an input stream to the specified table.
509
510        The input stream can be a file-like object with a read() method or
511        it can also be an iterable returning a row or multiple rows of input
512        on each iteration.
513
514        The format must be text, csv or binary. The sep option sets the
515        column separator (delimiter) used in the non binary formats.
516        The null option sets the textual representation of NULL in the input.
517
518        The size option sets the size of the buffer used when reading data
519        from file-like objects.
520
521        The copy operation can be restricted to a subset of columns. If no
522        columns are specified, all of them will be copied.
523        """
524        binary_format = format == 'binary'
525        try:
526            read = stream.read
527        except AttributeError:
528            if size:
529                raise ValueError("size must only be set for file-like objects")
530            if binary_format:
531                input_type = bytes
532                type_name = 'byte strings'
533            else:
534                input_type = basestring
535                type_name = 'strings'
536
537            if isinstance(stream, basestring):
538                if not isinstance(stream, input_type):
[740]539                    raise ValueError("The input must be %s" % type_name)
[693]540                if not binary_format:
541                    if isinstance(stream, str):
542                        if not stream.endswith('\n'):
543                            stream += '\n'
544                    else:
545                        if not stream.endswith(b'\n'):
546                            stream += b'\n'
547
548                def chunks():
549                    yield stream
550
551            elif isinstance(stream, Iterable):
552
553                def chunks():
554                    for chunk in stream:
555                        if not isinstance(chunk, input_type):
556                            raise ValueError(
[740]557                                "Input stream must consist of %s" % type_name)
[693]558                        if isinstance(chunk, str):
559                            if not chunk.endswith('\n'):
560                                chunk += '\n'
561                        else:
562                            if not chunk.endswith(b'\n'):
563                                chunk += b'\n'
564                        yield chunk
565
566            else:
[740]567                raise TypeError("Need an input stream to copy from")
[693]568        else:
569            if size is None:
570                size = 8192
[772]571            elif not isinstance(size, int):
572                raise TypeError("The size option must be an integer")
[693]573            if size > 0:
574
575                def chunks():
576                    while True:
577                        buffer = read(size)
578                        yield buffer
579                        if not buffer or len(buffer) < size:
580                            break
581
582            else:
583
584                def chunks():
585                    yield read()
586
587        if not table or not isinstance(table, basestring):
[740]588            raise TypeError("Need a table to copy to")
[693]589        if table.lower().startswith('select'):
[740]590            raise ValueError("Must specify a table, not a query")
[693]591        else:
592            table = '"%s"' % (table,)
593        operation = ['copy %s' % (table,)]
594        options = []
595        params = []
596        if format is not None:
597            if not isinstance(format, basestring):
[740]598                raise TypeError("format option must be a string")
[693]599            if format not in ('text', 'csv', 'binary'):
600                raise ValueError("invalid format")
601            options.append('format %s' % (format,))
602        if sep is not None:
603            if not isinstance(sep, basestring):
[740]604                raise TypeError("sep option must be a string")
[693]605            if format == 'binary':
606                raise ValueError("sep is not allowed with binary format")
607            if len(sep) != 1:
608                raise ValueError("sep must be a single one-byte character")
609            options.append('delimiter %s')
610            params.append(sep)
611        if null is not None:
612            if not isinstance(null, basestring):
[740]613                raise TypeError("null option must be a string")
[693]614            options.append('null %s')
615            params.append(null)
616        if columns:
617            if not isinstance(columns, basestring):
618                columns = ','.join('"%s"' % (col,) for col in columns)
619            operation.append('(%s)' % (columns,))
620        operation.append("from stdin")
621        if options:
622            operation.append('(%s)' % ','.join(options))
623        operation = ' '.join(operation)
624
625        putdata = self._src.putdata
626        self.execute(operation, params)
627
628        try:
629            for chunk in chunks():
630                putdata(chunk)
631        except BaseException as error:
632            self.rowcount = -1
633            # the following call will re-raise the error
634            putdata(error)
635        else:
636            self.rowcount = putdata(None)
637
638        # return the cursor object, so you can chain operations
639        return self
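
    # Illustrative use (file and table names are assumptions): load a CSV
    # file, or any iterable of text lines, into a table
    #
    #     with open('fruits.csv') as f:
    #         cur.copy_from(f, 'fruits', format='csv',
    #             columns=['name', 'size'])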
640
641    def copy_to(self, stream, table,
642            format=None, sep=None, null=None, decode=None, columns=None):
643        """Copy data from the specified table to an output stream.
644
645        The output stream can be a file-like object with a write() method or
646        it can also be None, in which case the method will return a generator
647        yielding a row on each iteration.
648
649        Output will be returned as byte strings unless you set decode to true.
650
651        Note that you can also use a select query instead of the table name.
652
653        The format must be text, csv or binary. The sep option sets the
654        column separator (delimiter) used in the non binary formats.
655        The null option sets the textual representation of NULL in the output.
656
657        The copy operation can be restricted to a subset of columns. If no
658        columns are specified, all of them will be copied.
659        """
660        binary_format = format == 'binary'
661        if stream is not None:
662            try:
663                write = stream.write
664            except AttributeError:
665                raise TypeError("need an output stream to copy to")
666        if not table or not isinstance(table, basestring):
667            raise TypeError("need a table to copy to")
668        if table.lower().startswith('select'):
669            if columns:
670                raise ValueError("columns must be specified in the query")
671            table = '(%s)' % (table,)
672        else:
673            table = '"%s"' % (table,)
674        operation = ['copy %s' % (table,)]
675        options = []
676        params = []
677        if format is not None:
678            if not isinstance(format, basestring):
[740]679                raise TypeError("format option must be a string")
[693]680            if format not in ('text', 'csv', 'binary'):
681                raise ValueError("invalid format")
682            options.append('format %s' % (format,))
683        if sep is not None:
684            if not isinstance(sep, basestring):
[740]685                raise TypeError("sep option must be a string")
[693]686            if binary_format:
687                raise ValueError("sep is not allowed with binary format")
688            if len(sep) != 1:
689                raise ValueError("sep must be a single one-byte character")
690            options.append('delimiter %s')
691            params.append(sep)
692        if null is not None:
693            if not isinstance(null, basestring):
[740]694                raise TypeError("null option must be a string")
[693]695            options.append('null %s')
696            params.append(null)
697        if decode is None:
698            if format == 'binary':
699                decode = False
700            else:
701                decode = str is unicode
702        else:
703            if not isinstance(decode, (int, bool)):
[740]704                raise TypeError("decode option must be a boolean")
[693]705            if decode and binary_format:
706                raise ValueError("decode is not allowed with binary format")
707        if columns:
708            if not isinstance(columns, basestring):
709                columns = ','.join('"%s"' % (col,) for col in columns)
710            operation.append('(%s)' % (columns,))
711
712        operation.append("to stdout")
713        if options:
714            operation.append('(%s)' % ','.join(options))
715        operation = ' '.join(operation)
716
717        getdata = self._src.getdata
718        self.execute(operation, params)
719
720        def copy():
[697]721            self.rowcount = 0
[693]722            while True:
723                row = getdata(decode)
724                if isinstance(row, int):
725                    if self.rowcount != row:
726                        self.rowcount = row
727                    break
728                self.rowcount += 1
729                yield row
730
731        if stream is None:
732            # no input stream, return the generator
733            return copy()
734
735        # write the rows to the file-like input stream
736        for row in copy():
737            write(row)
738
739        # return the cursor object, so you can chain operations
740        return self
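
    # Illustrative use (file and table names are assumptions): dump a table
    # to a file, or iterate over the rows when no stream is given
    #
    #     with open('fruits.csv', 'wb') as f:
    #         cur.copy_to(f, 'fruits', format='csv')
    #     for line in cur.copy_to(None, 'fruits', decode=True):
    #         print(line, end='')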
741
[601]742    def __next__(self):
[374]743        """Return the next row (support for the iteration protocol)."""
744        res = self.fetchone()
745        if res is None:
746            raise StopIteration
747        return res
748
[601]749    # Note that since Python 2.6 the iterator protocol uses __next__()
750    # instead of next(); we keep next() only for backward compatibility.
751    next = __next__
752
[600]753    @staticmethod
[346]754    def nextset():
755        """Not supported."""
756        raise NotSupportedError("nextset() is not supported")
[122]757
[678]758    @staticmethod
[346]759    def setinputsizes(sizes):
760        """Not supported."""
[463]761        pass  # unsupported, but silently passed
[27]762
[678]763    @staticmethod
[346]764    def setoutputsize(size, column=0):
765        """Not supported."""
[463]766        pass  # unsupported, but silently passed
[27]767
[683]768    @staticmethod
769    def row_factory(row):
770        """Process rows before they are returned.
[42]771
[683]772        You can overwrite this statically with a custom row factory, or
773        you can build a row factory dynamically with build_row_factory().
[681]774
[683]775        For example, you can create a Cursor class that returns rows as
776        Python dictionaries like this:
[681]777
[683]778            class DictCursor(pgdb.Cursor):
[682]779
[683]780                def row_factory(self, row):
781                    return {desc[0]: value
782                        for desc, value in zip(self.description, row)}
[682]783
[683]784            cur = DictCursor(con)  # get one DictCursor instance or
785            con.cursor_type = DictCursor  # always use DictCursor instances
786        """
787        raise NotImplementedError
[682]788
[683]789    def build_row_factory(self):
790        """Build a row factory based on the current description.
[682]791
[683]792        This implementation builds a row factory for creating named tuples.
793        You can overwrite this method if you want to dynamically create
794        different row factories whenever the column description changes.
795        """
796        colnames = self.colnames
797        if colnames:
[682]798            try:
799                try:
[683]800                    return namedtuple('Row', colnames, rename=True)._make
[682]801                except TypeError:  # Python 2.6 and 3.0 do not support rename
[683]802                    colnames = [v if v.isalnum() else 'column_%d' % n
803                             for n, v in enumerate(colnames)]
804                    return namedtuple('Row', colnames)._make
805            except ValueError:  # there is still a problem with the field names
806                colnames = ['column_%d' % n for n in range(len(colnames))]
807                return namedtuple('Row', colnames)._make
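
    # A sketch of what the default row factory produces (table and column
    # names are assumptions): rows are returned as named tuples built from
    # the description, so values can be accessed by name or by index
    #
    #     cur.execute("select city, temp from weather")
    #     row = cur.fetchone()
    #     row.city == row[0]    # True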
[682]808
809
[683]810CursorDescription = namedtuple('CursorDescription',
811    ['name', 'type_code', 'display_size', 'internal_size',
812     'precision', 'scale', 'null_ok'])
813
814
[340]815### Connection Objects
[184]816
[679]817class Connection(object):
[463]818    """Connection object."""
[27]819
[347]820    # expose the exceptions as attributes on the connection object
[373]821    Error = Error
[347]822    Warning = Warning
823    InterfaceError = InterfaceError
824    DatabaseError = DatabaseError
[373]825    InternalError = InternalError
[347]826    OperationalError = OperationalError
[373]827    ProgrammingError = ProgrammingError
[347]828    IntegrityError = IntegrityError
[373]829    DataError = DataError
[347]830    NotSupportedError = NotSupportedError
831
[346]832    def __init__(self, cnx):
833        """Create a database connection object."""
[436]834        self._cnx = cnx  # connection
835        self._tnx = False  # transaction state
[787]836        self.type_cache = TypeCache(cnx)
[682]837        self.cursor_type = Cursor
[346]838        try:
839            self._cnx.source()
[359]840        except Exception:
[434]841            raise _op_error("invalid connection")
[27]842
[438]843    def __enter__(self):
[463]844        """Enter the runtime context for the connection object.
845
846        The runtime context can be used for running transactions.
847        """
[438]848        return self
849
850    def __exit__(self, et, ev, tb):
[463]851        """Exit the runtime context for the connection object.
[438]852
[463]853        This does not close the connection, but it ends a transaction.
854        """
855        if et is None and ev is None and tb is None:
856            self.commit()
857        else:
858            self.rollback()
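
    # Illustrative transaction handling with the context manager (the
    # statement is an assumption); the connection itself stays open:
    #
    #     with con:
    #         con.cursor().execute("update accounts set checked = true")
    #     # committed on success, rolled back if an exception was raised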
859
[346]860    def close(self):
861        """Close the connection object."""
862        if self._cnx:
[437]863            if self._tnx:
864                try:
865                    self.rollback()
866                except DatabaseError:
867                    pass
[346]868            self._cnx.close()
869            self._cnx = None
870        else:
[434]871            raise _op_error("connection has been closed")
[27]872
[346]873    def commit(self):
874        """Commit any pending transaction to the database."""
875        if self._cnx:
876            if self._tnx:
877                self._tnx = False
878                try:
879                    self._cnx.source().execute("COMMIT")
[434]880                except DatabaseError:
881                    raise
[359]882                except Exception:
[434]883                    raise _op_error("can't commit")
[346]884        else:
[434]885            raise _op_error("connection has been closed")
[27]886
[346]887    def rollback(self):
888        """Roll back to the start of any pending transaction."""
889        if self._cnx:
890            if self._tnx:
891                self._tnx = False
892                try:
893                    self._cnx.source().execute("ROLLBACK")
[434]894                except DatabaseError:
895                    raise
[359]896                except Exception:
[434]897                    raise _op_error("can't rollback")
[346]898        else:
[434]899            raise _op_error("connection has been closed")
[27]900
[683]901    def cursor(self):
[463]902        """Return a new cursor object using the connection."""
[346]903        if self._cnx:
904            try:
[683]905                return self.cursor_type(self)
[359]906            except Exception:
[434]907                raise _op_error("invalid connection")
[346]908        else:
[434]909            raise _op_error("connection has been closed")
[27]910
[463]911    if shortcutmethods:  # otherwise do not implement and document this
[27]912
[463]913        def execute(self, operation, params=None):
914            """Shortcut method to run an operation on an implicit cursor."""
915            cursor = self.cursor()
916            cursor.execute(operation, params)
917            return cursor
918
919        def executemany(self, operation, param_seq):
920            """Shortcut method to run an operation against a sequence."""
921            cursor = self.cursor()
922            cursor.executemany(operation, param_seq)
923            return cursor
924
925
[307]926### Module Interface
927
[175]928_connect_ = connect
[343]929
[307]930def connect(dsn=None,
[346]931        user=None, password=None,
932        host=None, database=None):
[740]933    """Connect to a database."""
[346]934    # first get params from DSN
935    dbport = -1
936    dbhost = ""
937    dbbase = ""
938    dbuser = ""
939    dbpasswd = ""
940    dbopt = ""
941    try:
942        params = dsn.split(":")
943        dbhost = params[0]
944        dbbase = params[1]
945        dbuser = params[2]
946        dbpasswd = params[3]
947        dbopt = params[4]
[367]948    except (AttributeError, IndexError, TypeError):
[346]949        pass
[27]950
[346]951    # override if necessary
952    if user is not None:
953        dbuser = user
954    if password is not None:
955        dbpasswd = password
956    if database is not None:
957        dbbase = database
958    if host is not None:
959        try:
960            params = host.split(":")
961            dbhost = params[0]
962            dbport = int(params[1])
[367]963        except (AttributeError, IndexError, TypeError, ValueError):
[346]964            pass
[27]965
[346]966    # empty host is localhost
967    if dbhost == "":
968        dbhost = None
969    if dbuser == "":
970        dbuser = None
[27]971
[346]972    # open the connection
[680]973    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
[679]974    return Connection(cnx)
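
# Illustrative connections (the connection parameters are assumptions):
#
#     con = connect('localhost:testdb')             # DSN form
#     con = connect(database='testdb', host='localhost:5432',
#         user='postgres', password='secret')       # keyword form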
[27]975
976
[307]977### Types Handling
[27]978
[679]979class Type(frozenset):
[346]980    """Type class for a couple of PostgreSQL data types.
[27]981
[346]982    PostgreSQL is object-oriented: types are dynamic.
983    We must thus use type names as internal type codes.
984    """
[307]985
[599]986    def __new__(cls, values):
987        if isinstance(values, basestring):
988            values = values.split()
[679]989        return super(Type, cls).__new__(cls, values)
[27]990
[346]991    def __eq__(self, other):
992        if isinstance(other, basestring):
[781]993            if other.startswith('_'):
994                other = other[1:]
[346]995            return other in self
996        else:
[679]997            return super(Type, self).__eq__(other)
[27]998
[346]999    def __ne__(self, other):
1000        if isinstance(other, basestring):
[781]1001            if other.startswith('_'):
1002                other = other[1:]
[346]1003            return other not in self
1004        else:
[679]1005            return super(Type, self).__ne__(other)
[307]1006
[342]1007
[781]1008class ArrayType:
1009    """Type class for PostgreSQL array types."""
1010
1011    def __eq__(self, other):
1012        if isinstance(other, basestring):
1013            return other.startswith('_')
1014        else:
1015            return isinstance(other, ArrayType)
1016
1017    def __ne__(self, other):
1018        if isinstance(other, basestring):
1019            return not other.startswith('_')
1020        else:
1021            return not isinstance(other, ArrayType)
1022
1023
[302]1024# Mandatory type objects defined by DB-API 2 specs:
[27]1025
[679]1026STRING = Type('char bpchar name text varchar')
1027BINARY = Type('bytea')
1028NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
[781]1029DATETIME = Type('date time timetz timestamp timestamptz interval'
1030    ' abstime reltime')  # abstime and reltime are very old
1031ROWID = Type('oid')
[302]1032
[307]1033
[302]1034# Additional type objects (more specific):
1035
[679]1036BOOL = Type('bool')
1037SMALLINT = Type('int2')
1038INTEGER = Type('int2 int4 int8 serial')
1039LONG = Type('int8')
1040FLOAT = Type('float4 float8')
1041NUMERIC = Type('numeric')
1042MONEY = Type('money')
1043DATE = Type('date')
1044TIME = Type('time timetz')
[781]1045TIMESTAMP = Type('timestamp timestamptz')
1046INTERVAL = Type('interval')
[774]1047JSON = Type('json jsonb')
[27]1048
[781]1049# Type object for arrays (also equate to their base types):
[307]1050
[781]1051ARRAY = ArrayType()
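
# Illustrative comparisons with the type objects above; the type_code values
# in cursor.description are PostgreSQL type names and compare against these
# sets (and against ARRAY for array types):
#
#     cur.execute("select 42")
#     cur.description[0].type_code == NUMBER    # True ('int4')
#     'timestamptz' == DATETIME                 # True
#     '_int4' == ARRAY                          # True (array of int4)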
1052
1053
[302]1054# Mandatory type helpers defined by DB-API 2 specs:
[27]1055
1056def Date(year, month, day):
[346]1057    """Construct an object holding a date value."""
[423]1058    return date(year, month, day)
[27]1059
[436]1060
[423]1061def Time(hour, minute=0, second=0, microsecond=0):
[346]1062    """Construct an object holding a time value."""
[423]1063    return time(hour, minute, second, microsecond)
[27]1064
[436]1065
[423]1066def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
[774]1067    """Construct an object holding a time stamp value."""
[423]1068    return datetime(year, month, day, hour, minute, second, microsecond)
[27]1069
[436]1070
[27]1071def DateFromTicks(ticks):
[346]1072    """Construct an object holding a date value from the given ticks value."""
[423]1073    return Date(*localtime(ticks)[:3])
[27]1074
[436]1075
[27]1076def TimeFromTicks(ticks):
[774]1077    """Construct an object holding a time value from the given ticks value."""
[423]1078    return Time(*localtime(ticks)[3:6])
[27]1079
[436]1080
[27]1081def TimestampFromTicks(ticks):
[774]1082    """Construct an object holding a time stamp from the given ticks value."""
[423]1083    return Timestamp(*localtime(ticks)[:6])
[222]1084
[436]1085
[629]1086class Binary(bytes):
[774]1087    """Construct an object capable of holding a binary (long) string value."""
[230]1088
[307]1089
[774]1090# Additional type helpers for PyGreSQL:
1091
1092class Json:
1093    """Construct a wrapper for holding an object serializable to JSON."""
1094
1095    def __init__(self, obj, encode=None):
1096        self.obj = obj
1097        self.encode = encode or jsonencode
1098
1099    def __str__(self):
1100        obj = self.obj
1101        if isinstance(obj, basestring):
1102            return obj
1103        return self.encode(obj)
1104
1105    __pg_repr__ = __str__
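
# Illustrative use of the Json wrapper (table and column are assumptions);
# on output, json and jsonb columns are decoded back to Python objects:
#
#     cur.execute("insert into books (data) values (%s)",
#         [Json({'title': 'Moby-Dick', 'year': 1851})])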
1106
1107
[340]1108# If run as script, print some information:
[307]1109
[222]1110if __name__ == '__main__':
[422]1111    print('PyGreSQL version', version)
1112    print('')
1113    print(__doc__)