source: trunk/pgdb.py @ 791

Last change on this file since 791 was 791, checked in by cito, 3 years ago

Add support for composite types

Added a fast parser for the composite type input/output syntax, which is
similar to the already existing parser for the array input/output syntax.

The pgdb module now makes use of this parser, converting in both directions
between PostgreSQL records (composite types) and Python (named) tuples.

  • Property svn:keywords set to Id
File size: 38.3 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 791 2016-01-27 23:09:18Z cito $
8#
9
"""pgdb - DB-API 2.0 compliant module for PygreSQL.

(c) 1999, Pascal Andre <andre@via.ecp.fr>.
See package documentation for further information on copyright.

Inline documentation is sparse.
See DB-API 2.0 specification for usage information:
https://peps.python.org/pep-0249/

Basic usage:

    pgdb.connect(connect_string) # open a connection
    # connect_string = 'host:database:user:password:opt'
    # All parts are optional. You may also pass host through
    # password as keyword arguments. To pass a port,
    # pass it in the host keyword parameter:
    connection = pgdb.connect(host='localhost:5432')

    cursor = connection.cursor() # open a cursor

    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is the same as for the
    # % operator (%(name)s for dictionaries, %s for sequences); the
    # values are quoted and escaped according to their Python type
    # before they are inserted into the query.

    cursor.executemany(query, list of params)
    # Execute a query many times, binding each set of params
    # from the list.

    cursor.fetchone() # fetch one row, (value, value, ...)
    # Rows are passed through the row factory of the cursor;
    # by default they are named tuples with the column names
    # as field names.

    cursor.fetchall() # fetch all rows, [(value, value, ...), ...]

    cursor.fetchmany([size])
    # returns size or cursor.arraysize number of rows,
    # [(value, value, ...), ...] from result set.
    # Default cursor.arraysize is 1.

    cursor.description # returns information about the columns
    #   [(column_name, type_name, display_size,
    #           internal_size, precision, scale, null_ok), ...]
    # Note that display_size and null_ok are not implemented,
    # and precision and scale are only set for numeric columns.

    cursor.rowcount # number of rows available in the result set
    # Available after a call to execute.

    connection.commit() # commit transaction

    connection.rollback() # or rollback transaction

    cursor.close() # close the cursor

    connection.close() # close the connection
"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from re import compile as regex
76from json import loads as jsondecode, dumps as jsonencode
77
# Compatibility names for code that must run on Python 2 and Python 3.

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# The abstract base classes live in collections.abc since Python 3.3;
# the old aliases in collections were removed in Python 3.10.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args):
            raise NotSupportedError('OrderedDict is not supported')
102
103
### Module Constants

# the DB-API level this module complies with
apilevel = '2.0'

# threads may share the module itself, but not its connections
threadsafety = 1

# query parameters are written with extended Python format codes,
# e.g. %(name)s for mappings or %s for sequences
paramstyle = 'pyformat'

# shortcut methods are not part of DB-API 2 and are not recommended
# by the DB-SIG, but they are offered because they can be handy
shortcutmethods = 1
118
119
120### Internal Types Handling
121
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal values.

    Called without argument (or with None), the current type is returned.
    Otherwise the given type becomes both the cast function for 'numeric'
    columns and the module-wide Decimal type, and it is returned.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = Decimal = decimal_type
    return Decimal
128
129
130def _cast_bool(value):
131    return value[:1] in ('t', 'T')
132
133
134def _cast_money(value):
135    return Decimal(''.join(filter(
136        lambda v: v in '0123456789.-', value)))
137
138
# Cast functions for database types, keyed by type name.  Types mapped
# to str need no conversion; values of types not listed here are
# returned unchanged by TypeCache.typecast, unless they are arrays or
# composite types.
_cast = {
    # character types
    'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str,
    # boolean and binary types
    'bool': _cast_bool, 'bytea': unescape_bytea,
    # integer types
    'int2': int, 'int4': int, 'serial': int, 'int8': long,
    'oid': long, 'oid8': long,
    # floating point and exact numeric types
    'float4': float, 'float8': float,
    'numeric': Decimal, 'money': _cast_money,
    # JSON types
    'json': jsondecode, 'jsonb': jsondecode,
    # anonymous composite types
    'record': cast_record,
}
148
149
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an error with an empty sqlstate.

    The error is an instance of cls (DatabaseError by default) carrying
    msg, and its sqlstate attribute is set to None.
    """
    err = cls(msg)
    err.sqlstate = None
    return err
155
156
def _op_error(msg):
    """Create (but do not raise) an OperationalError with empty sqlstate."""
    return _db_error(msg, cls=OperationalError)
160
161
# Information on a database type, taken from the pg_type catalog
TypeInfo = namedtuple(
    'TypeInfo', 'oid name len type category delim relid')

# Name and type OID of one attribute of a composite type
ColumnInfo = namedtuple('ColumnInfo', 'name type')
166
167
class TypeCache(dict):
    """Cache for database types.

    This cache maps type OIDs and names to TypeInfo tuples containing
    important information on the associated database type.  Types that
    are not cached yet are looked up in the pg_type system catalog.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The key can be a type OID (int) or a type name.  Names that
        contain neither a dot nor a double quote are put in double quotes
        before being cast to regtype.  The result is cached under both
        its OID and its name.  Raises KeyError if the type is not found.
        """
        if isinstance(key, int):
            oid = key
        else:
            if '.' not in key and '"' not in key:
                key = '"%s"' % key
            oid = "'%s'::regtype" % self._escape_string(key)
        try:
            self._src.execute("SELECT oid, typname,"
                " typlen, typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s" % oid)
        except ProgrammingError:
            # presumably raised when the name cannot be cast to regtype
            res = None
        else:
            res = self._src.fetch(1)
        if not res:
            raise KeyError('Type %s could not be found' % key)
        res = list(res[0])
        # the OID, length and relation ID columns are converted to int
        res[0] = int(res[0])
        res[2] = int(res[2])
        res[6] = int(res[6])
        res = TypeInfo(*res)
        self[res.oid] = self[res.name] = res
        return res

    def get(self, key, default=None):
        """Get the type even if it is not cached.

        Returns default if the type cannot be found in the database.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def columns(self, key):
        """Get the names and types of the columns of composite types.

        Returns a list of ColumnInfo tuples, or None if the type is not
        known or is not a composite type.
        """
        try:
            typ = self[key]
        except KeyError:
            return None  # this type is not known
        if typ.type != 'c' or not typ.relid:
            return None  # this type is not composite
        self._src.execute("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
        return [ColumnInfo(name, int(oid))
            for name, oid in self._src.fetch(-1)]

    @staticmethod
    def _record_type(typ, fields):
        """Create a named tuple class for a composite type.

        Database identifiers are not always valid Python identifiers.
        Field names that cannot be used (e.g. names starting with an
        underscore, or Python keywords) are replaced by positional names.
        If the type name itself cannot be used, 'Record' is used instead.
        """
        try:
            try:
                return namedtuple(typ, fields, rename=True)
            except TypeError:  # Python 2.6 and 3.0 do not support rename
                return namedtuple(typ, fields)
        except ValueError:  # the type name or the field names are unusable
            pass
        try:
            return namedtuple('Record', fields, rename=True)
        except (TypeError, ValueError):  # no rename and bad field names
            return namedtuple('Record',
                ['column_%d' % n for n in range(len(fields))])

    def typecast(self, typ, value):
        """Cast value according to database type.

        typ is the name of the database type.  None (NULL) is returned
        as None.  Types with a registered cast function are cast with it.
        Array types (names starting with an underscore) are parsed with
        cast_array, and composite types are parsed with cast_record and
        returned as named tuples.  Values of all other types are
        returned unchanged.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        cast = _cast.get(typ)
        if cast is str:
            return value  # no typecast necessary
        if cast is not None:
            return cast(value)
        if typ.startswith('_'):
            # cast as an array type
            return cast_array(value, _cast.get(typ[1:]))
        # check whether this is a composite type
        cols = self.columns(typ)
        if cols:
            getcast = self.getcast
            cast = [getcast(col.type) for col in cols]
            value = cast_record(value, cast)
            record = self._record_type(typ, [col.name for col in cols])
            return record(*value)
        return value  # no typecast available or necessary

    def getcast(self, key):
        """Get a cast function for the given database type.

        The key can be a type OID or a type name.  Returns None if an
        OID is given that cannot be found.
        """
        if isinstance(key, int):
            try:
                typ = self[key].name
            except KeyError:
                return None
        else:
            typ = key
        typecast = self.typecast
        return lambda value: typecast(typ, value)
265
266
267_re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
268_re_record_quote = regex(r'[(,"\\]')
269_re_array_escape = _re_record_escape = regex(r'(["\\])')
270
271
272class _quotedict(dict):
273    """Dictionary with auto quoting of its items.
274
275    The quote attribute must be set to the desired quote function.
276    """
277
278    def __getitem__(self, key):
279        return self.quote(super(_quotedict, self).__getitem__(key))
280
281
282### Cursor Object
283
class Cursor(object):
    """Cursor object.

    A cursor executes operations on the database connection it has been
    created for and provides access to the results.  Rows are returned
    through a row factory (see row_factory and build_row_factory).
    """

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self.type_cache = dbcnx.type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns
        # (True means that it will be built on demand)
        self._description = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a subclass provides a static row factory which must not
            # be replaced with a dynamically built one
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object.

        The cursor is closed; exceptions are not suppressed.
        """
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Strings, binary data, dates and times, lists (as arrays) and
        tuples (as records) are returned as quoted and escaped SQL
        literals; numbers are returned as they are, so that they can be
        inserted with the % operator.  Other objects must provide a
        __pg_repr__() method, otherwise InterfaceError is raised.
        """
        if val is None:
            return 'NULL'
        if isinstance(val, (datetime, date, time, timedelta, Json)):
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            return "'%s'" % val
        if isinstance(val, float):
            # special float values must be passed as quoted strings
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            if isnan(val):
                return "'NaN'"
            return val
        if isinstance(val, (int, long, Decimal)):
            return val
        # array and record literals may contain single quotes coming from
        # string elements, so they must be escaped like any other string
        if isinstance(val, list):
            return "'%s'" % self._cnx.escape_string(self._quote_array(val))
        if isinstance(val, tuple):
            return "'%s'" % self._cnx.escape_string(self._quote_record(val))
        try:
            return val.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))

    def _quote_array(self, val):
        """Quote value as a literal constant for an array."""
        q = self._quote_array_element
        return '{%s}' % ','.join(q(v) for v in val)

    def _quote_array_element(self, val):
        """Quote value using the output syntax for arrays.

        Nested lists become nested arrays, tuples become records.
        Raises InterfaceError for unsupported element types.
        """
        if isinstance(val, list):
            return self._quote_array(val)
        if val is None:
            return 'null'
        # bool must be checked before int, since bool is a subclass of int
        if isinstance(val, bool):
            return 't' if val else 'f'
        if isinstance(val, (int, long, float)):
            return str(val)
        if isinstance(val, tuple):
            val = self._quote_record(val)
        if isinstance(val, basestring):
            if not val:
                return '""'
            if _re_array_quote.search(val):
                return '"%s"' % _re_array_escape.sub(r'\\\1', val)
            return val
        raise InterfaceError(
            'do not know how to handle base type %s' % type(val))

    def _quote_record(self, val):
        """Quote value as a literal constant for a record."""
        q = self._quote_record_element
        return '(%s)' % ','.join(q(v) for v in val)

    def _quote_record_element(self, val):
        """Quote value using the output syntax for records.

        None becomes an empty (NULL) field, lists become arrays and
        nested tuples become nested records.  Raises InterfaceError for
        unsupported component types.
        """
        if val is None:
            return ''
        # bool must be checked before int, since bool is a subclass of int
        if isinstance(val, bool):
            return 't' if val else 'f'
        if isinstance(val, (int, long, float)):
            return str(val)
        if isinstance(val, list):
            val = self._quote_array(val)
        elif isinstance(val, tuple):
            val = self._quote_record(val)
        if isinstance(val, basestring):
            if not val:
                return '""'
            if _re_record_quote.search(val):
                return '"%s"' % _re_record_escape.sub(r'\\\1', val)
            return val
        raise InterfaceError(
            'do not know how to handle component type %s' % type(val))

    def _quoteparams(self, string, parameters):
        """Quote parameters.

        This function works for both mappings and sequences.
        """
        if isinstance(parameters, dict):
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def _make_description(self, info):
        """Make the description tuple for the given field info.

        The items of info after the first one are taken as the column
        name, the type OID, the size and the type modifier.
        """
        name, typ, size, mod = info[1:]
        type_info = self.type_cache[typ]
        type_code = type_info.name
        if mod > 0:
            # positive type modifiers (e.g. of varchar and numeric)
            # are offset by the 4 byte varlena header size
            mod -= 4
        if type_code == 'numeric':
            if mod >= 0:
                precision, scale = mod >> 16, mod & 0xffff
            else:  # numeric without declared precision and scale
                precision = scale = None
            size = precision
        else:
            if not size:
                size = type_info.len
            if size == -1:
                size = mod
            precision = scale = None
        return CursorDescription(name, type_code,
            None, size, precision, scale, None)

    @property
    def description(self):
        """Read-only attribute describing the result columns."""
        descr = self._description
        if descr is True:
            make = self._make_description
            descr = [make(info) for info in self._src.listinfo()]
            self._description = descr
        return descr

    @property
    def colnames(self):
        """Unofficial convenience method for getting the column names."""
        return [d[0] for d in self.description]

    @property
    def coltypes(self):
        """Unofficial convenience method for getting the column types."""
        return [d[1] for d in self.description]

    def close(self):
        """Close the cursor object."""
        self._src.close()
        self._description = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command).

        Returns the cursor itself (see executemany).
        """
        # The parameters may also be specified as list of tuples to e.g.
        # insert multiple rows in a single operation, but this kind of
        # usage is deprecated.  We make several plausibility checks because
        # tuples can also be passed with the meaning of ROW constructors.
        if (parameters and isinstance(parameters, list)
                and len(parameters) > 1
                and all(isinstance(p, tuple) for p in parameters)
                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
            return self.executemany(operation, parameters)
        else:
            # not a list of tuples
            return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        If the connection is not in a transaction yet, a transaction is
        started with BEGIN first.  Returns the cursor itself, so that
        calls can be chained, or None if seq_of_parameters is empty.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self._description = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                if parameters:
                    sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                # NOTE(review): presumably rows is the number of affected
                # rows for commands and false otherwise; confirm in _pg
                if rows:
                    rowcount += rows
                else:
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result raw count and description
        if self._src.resulttype == RESULT_DQL:
            self._description = True  # fetch on demand
            self.rowcount = self._src.ntuples
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set.

        Returns None when no more rows are available.
        """
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        if not result:
            return []
        # determine the column types and the row factory only once,
        # not again for every single row
        typecast = self.type_cache.typecast
        coltypes = self.coltypes
        row_factory = self.row_factory
        return [row_factory([typecast(typ, value)
            for typ, value in zip(coltypes, row)]) for row in result]

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = len(parameters) if parameters else 0
        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
        # the parameters always form one single set of arguments, so a
        # list of tuples (e.g. composite arguments) must not be treated
        # as several rows the way execute() would do it
        self.executemany(query, [parameters])
        return parameters

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.

        The size option sets the size of the buffer used when reading data
        from file-like objects.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            if size:
                raise ValueError("size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("The input must be %s" % type_name)
                if not binary_format:
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    else:
                        if not stream.endswith(b'\n'):
                            stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "Input stream must consist of %s" % type_name)
                        # rows in text or csv format must end with a
                        # newline, but binary data must not be altered
                        if not binary_format:
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            else:
                                if not chunk.endswith(b'\n'):
                                    chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("Need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, int):
                raise TypeError("The size option must be an integer")
            if size > 0:

                def chunks():
                    # read() may return less than size before the end of
                    # the stream, so only an empty result ends the input
                    while True:
                        buffer = read(size)
                        if not buffer:
                            break
                        yield buffer

            else:

                def chunks():
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("Must specify a table, not a query")
        table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if format == 'binary':
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy from")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            if format == 'binary':
                decode = False
            else:
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("decode option must be a boolean")
            if decode and binary_format:
                raise ValueError("decode is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            self.rowcount = 0
            while True:
                row = getdata(decode)
                if isinstance(row, int):
                    # an int instead of a row ends the output and is
                    # taken as the final row count
                    self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no output stream, return the generator
            return copy()

        # write the rows to the file-like output stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Note that since Python 2.6 the iterator protocol uses __next()__
    # instead of next(), we keep it only for backward compatibility of pgdb.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.
        Returns None if there are no column names.
        """
        colnames = self.colnames
        if colnames:
            try:
                try:
                    return namedtuple('Row', colnames, rename=True)._make
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    colnames = [v if v.isalnum() else 'column_%d' % n
                             for n, v in enumerate(colnames)]
                    return namedtuple('Row', colnames)._make
            except ValueError:  # there is still a problem with the field names
                colnames = ['column_%d' % n for n in range(len(colnames))]
                return namedtuple('Row', colnames)._make
872
873
# One entry of cursor.description: the seven items DB-API 2 prescribes.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
877
878
879### Connection Objects
880
class Connection(object):
    """DB-API 2 connection object wrapping a low-level connection.

    The _tnx flag tracks whether a transaction is pending; commit() and
    rollback() only send a command to the server while it is set.
    """

    # make the module exceptions reachable through the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Wrap the low-level connection cnx and check that it is usable."""
        self._cnx = cnx  # underlying connection, set to None by close()
        self._tnx = False  # whether a transaction is pending
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # class instantiated by cursor()
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("invalid connection")

    def __enter__(self):
        """Enter a runtime context, e.g. for running a transaction."""
        return self

    def __exit__(self, et, ev, tb):
        """Leave the runtime context, ending but not closing the connection.

        The pending transaction is committed when the block finished
        without an exception, and rolled back otherwise.
        """
        if any(arg is not None for arg in (et, ev, tb)):
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection, rolling back a pending transaction first."""
        if not self._cnx:
            raise _op_error("connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # the connection is closed regardless
        self._cnx.close()
        self._cnx = None

    def commit(self):
        """Commit the pending transaction, if there is one."""
        if not self._cnx:
            raise _op_error("connection has been closed")
        if not self._tnx:
            return
        self._tnx = False
        try:
            self._cnx.source().execute("COMMIT")
        except DatabaseError:
            raise
        except Exception:
            raise _op_error("can't commit")

    def rollback(self):
        """Roll back the pending transaction, if there is one."""
        if not self._cnx:
            raise _op_error("connection has been closed")
        if not self._tnx:
            return
        self._tnx = False
        try:
            self._cnx.source().execute("ROLLBACK")
        except DatabaseError:
            raise
        except Exception:
            raise _op_error("can't rollback")

    def cursor(self):
        """Create a new cursor of class cursor_type on this connection."""
        if not self._cnx:
            raise _op_error("connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Run an operation on a fresh implicit cursor and return it."""
            implicit = self.cursor()
            implicit.execute(operation, params)
            return implicit

        def executemany(self, operation, param_seq):
            """Run an operation per parameter set on an implicit cursor."""
            implicit = self.cursor()
            implicit.executemany(operation, param_seq)
            return implicit
988
989
990### Module Interface
991
_connect_ = connect

def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database and return a Connection object.

    The dsn string has the form 'host:database:user:password:opt', where
    trailing parts may be omitted. The user, password and database keywords
    override the corresponding dsn parts; host may be given as 'host' or
    as 'host:port'. Empty host and user values are passed on as None.
    """
    # fill the DSN parts in order, stopping at the first missing one
    dsn_parts = ["", "", "", "", ""]
    try:
        pieces = dsn.split(":")
        for index in range(len(dsn_parts)):
            dsn_parts[index] = pieces[index]
    except (AttributeError, IndexError, TypeError):
        pass  # absent or malformed parts keep their empty defaults
    dbhost, dbbase, dbuser, dbpasswd, dbopt = dsn_parts
    dbport = -1

    # explicitly passed keyword arguments take precedence over the DSN
    dbuser = dbuser if user is None else user
    dbpasswd = dbpasswd if password is None else password
    dbbase = dbbase if database is None else database
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # no usable port given, keep the default

    # an empty host means localhost; empty host and user become None
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    # open the low-level connection and wrap it
    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
1039
1040
1041### Types Handling
1042
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    An instance is a frozenset of type names. It compares equal to a string
    if that string is one of its names, ignoring one leading underscore (the
    marker of array type names), so array types also equal the type object
    of their element type. Compared with anything else it behaves like a
    plain frozenset.
    """

    def __new__(cls, values):
        # also accept a single whitespace-separated string of type names
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Python 3 sets __hash__ to None in any class overriding __eq__, which
    # would make type objects unhashable although frozensets are hashable.
    # Reuse the frozenset hash, which is consistent with equality between
    # Type objects (note that a Type does not hash like the names it equals).
    __hash__ = frozenset.__hash__
1070
1071
class ArrayType:
    """Type class for PostgreSQL array types.

    An instance compares equal to every type name starting with an
    underscore (the naming used for array types) and to any other
    ArrayType instance.
    """

    def __eq__(self, other):
        if isinstance(other, basestring):
            return other.startswith('_')
        else:
            return isinstance(other, ArrayType)

    def __ne__(self, other):
        if isinstance(other, basestring):
            return not other.startswith('_')
        else:
            return not isinstance(other, ArrayType)

    def __hash__(self):
        # All ArrayType instances are equal to each other, so they must share
        # one hash. Defining it also keeps them hashable under Python 3, where
        # overriding __eq__ alone sets __hash__ to None.
        return hash(ArrayType)
1086
1087
# Type objects required by the DB-API 2 specification; each compares equal
# to the PostgreSQL type names listed for it:

STRING = Type(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = Type(['bytea'])
NUMBER = Type(['int2', 'int4', 'serial', 'int8',
               'float4', 'float8', 'numeric', 'money'])
DATETIME = Type(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
                 'interval', 'abstime', 'reltime'])  # the last two are very old
ROWID = Type(['oid'])


# More specific type objects provided in addition by PyGreSQL:

BOOL = Type(['bool'])
SMALLINT = Type(['int2'])
INTEGER = Type(['int2', 'int4', 'int8', 'serial'])
LONG = Type(['int8'])
FLOAT = Type(['float4', 'float8'])
NUMERIC = Type(['numeric'])
MONEY = Type(['money'])
DATE = Type(['date'])
TIME = Type(['time', 'timetz'])
TIMESTAMP = Type(['timestamp', 'timestamptz'])
INTERVAL = Type(['interval'])
JSON = Type(['json', 'jsonb'])

# Type object equal to any array type name (array type names additionally
# equal the type objects of their element types):

ARRAY = ArrayType()
1116
1117
1118# Mandatory type helpers defined by DB-API 2 specs:
1119
def Date(year, month, day):
    """Return a datetime.date built from year, month and day."""
    return date(year=year, month=month, day=day)
1123
1124
def Time(hour, minute=0, second=0, microsecond=0):
    """Return a datetime.time; parts that are not given default to zero."""
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond)
1128
1129
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Return a datetime.datetime; time parts default to midnight."""
    return datetime(year=year, month=month, day=day, hour=hour,
                    minute=minute, second=second, microsecond=microsecond)
1133
1134
def DateFromTicks(ticks):
    """Build a date value from seconds since the epoch, in local time."""
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
1138
1139
def TimeFromTicks(ticks):
    """Build a time value (whole seconds) from seconds since the epoch."""
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
1143
1144
def TimestampFromTicks(ticks):
    """Build a timestamp (whole seconds) from seconds since the epoch."""
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
1148
1149
class Binary(bytes):
    """Hold a binary (long) string value, as required by DB-API 2.

    This is a plain bytes subclass: values are built and compared exactly
    like bytes objects.
    """
1152
1153
1154# Additional type helpers for PyGreSQL:
1155
class Json:
    """Wrap an object so that it is passed to the database as JSON text.

    Non-string objects are serialized with the given encode function, or
    with jsonencode when no (or a false) encode function is given. Strings
    are taken to be JSON already and are passed through unchanged.
    """

    def __init__(self, obj, encode=None):
        self.obj = obj
        if not encode:
            encode = jsonencode
        self.encode = encode

    def __str__(self):
        value = self.obj
        if not isinstance(value, basestring):
            value = self.encode(value)
        return value

    # presumably consulted when adapting query parameters; it yields the
    # same JSON text as str()
    __pg_repr__ = __str__
1170
1171
# If run as script, print some information:

if __name__ == '__main__':
    # Show the version (presumably the module-level PyGreSQL version string
    # defined elsewhere in this file), a blank line and the module docstring.
    # NOTE(review): unless print_function is imported from __future__ (not
    # visible here), Python 2 prints the first call as a tuple -- confirm.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.