source: trunk/pgdb.py @ 784

Last change on this file since 784 was 784, checked in by cito, 4 years ago

Make type cache and cursor description more useful

The type cache now stores some more information, e.g. whether a type is a base
type or a composite type and the category of the type. This may be later used
for casting composite types, or exposed to the user.

The cursor description now contains proper information on the size of numeric
types (including precision and scale).

  • Property svn:keywords set to Id
File size: 35.1 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 784 2016-01-26 17:17:23Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is the same as the %
    # operator; parameter values are quoted and escaped automatically.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
    # Note that display_size and null_ok are not implemented;
    # precision and scale are only provided for numeric columns.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from re import compile as regex
76from json import loads as jsondecode, dumps as jsonencode
77
# Python 2 provides long, unicode and basestring as built-ins; on
# Python >= 3.0 none of them exists, so map them to their Python 3
# counterparts to let the rest of the module use one spelling.
try:
    long, unicode, basestring = long, unicode, basestring
except NameError:  # Python >= 3.0
    long, unicode, basestring = int, str, (str, bytes)
92
# The abstract base classes live in collections.abc; the aliases in
# collections itself were removed in Python 3.10.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args):
            """Stand-in that reports that OrderedDict is not available."""
            raise NotSupportedError('OrderedDict is not supported')
102
103
104### Module Constants
105
# Module globals required by the DB API 2.0 specification (PEP 249).
apilevel = "2.0"  # DB API level this module complies with
threadsafety = 1  # threads may share the module, but not connections
paramstyle = "pyformat"  # Python extended format codes, e.g. %(name)s

# Non-standard shortcut methods on connections. They were excluded from
# DB API 2 and are not recommended by the DB SIG, but can be handy; when
# this flag is false, the Connection class does not define them.
shortcutmethods = 1
118
119
120### Internal Types Handling
121
def decimal_type(decimal_type=None):
    """Return the type used for decimal values, optionally replacing it.

    When decimal_type is given, it is registered as the typecast for
    'numeric' columns and becomes the module-wide Decimal type (which is
    also used when casting money values).  The current type is returned.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    return Decimal
128
129
def _cast_bool(value):
    """Convert boolean text output such as 't' or 'f' to a Python bool.

    Only the first character matters: 't' or 'T' gives True, anything
    else (including an empty string) gives False.
    """
    first = value[:1]
    return first == 't' or first == 'T'
132
133
def _cast_money(value):
    """Cast a money value in database output format to Decimal.

    Currency symbols, group separators and other decoration are dropped,
    keeping only digits, the decimal point and the minus sign.  Some
    locales render negative amounts in parentheses, e.g. '($1.00)'; the
    opening parenthesis is turned into a minus sign so that the sign is
    not silently lost.

    NOTE: locales that use a comma as the decimal separator are not
    handled, since commas are dropped like group separators.
    """
    value = value.replace('(', '-')
    return Decimal(''.join(c for c in value if c in '0123456789.-'))
137
138
# Typecast functions keyed by PostgreSQL type name.  TypeCache.typecast
# returns values of types not listed here unchanged (arrays aside); the
# 'numeric' entry is replaced when decimal_type() sets a new type.
_cast = {
    # boolean and binary data
    'bool': _cast_bool,
    'bytea': unescape_bytea,
    # integer types
    'int2': int,
    'int4': int,
    'serial': int,
    'int8': long,
    # JSON documents
    'json': jsondecode,
    'jsonb': jsondecode,
    # object identifiers
    'oid': long,
    'oid8': long,
    # floating point, exact numeric and money types
    'float4': float,
    'float8': float,
    'numeric': Decimal,
    'money': _cast_money,
}
145
146
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls (DatabaseError by default).

    The returned exception has its sqlstate attribute set to None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
152
153
def _op_error(msg):
    """Create an OperationalError (with sqlstate None) for msg."""
    return _db_error(msg, cls=OperationalError)
157
158
# Information about one database type, as fetched from pg_type by
# TypeCache.__missing__ (columns in the same order as queried there).
TypeInfo = namedtuple('TypeInfo', (
    'oid',  # type OID (pg_type.oid)
    'name',  # type name (typname)
    'len',  # storage size in bytes, negative if variable (typlen)
    'type',  # kind of type, e.g. base or composite (typtype)
    'category',  # type category code (typcategory)
    'delim',  # array element delimiter (typdelim)
    'relid',  # relation OID for composite types (typrelid)
))
161
162
class TypeCache(dict):
    """Cache for database types.

    This cache maps type OIDs to TypeInfo tuples containing the name
    and other important info for the database type with the given OID.
    Types can also be looked up by name; a type fetched from the
    database is stored under both its OID and its name.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()

    def __missing__(self, key):
        """Fetch and cache the info for an uncached type OID or name.

        Integer keys are looked up as type OIDs, all other keys as type
        names.  Raises KeyError if the database has no such type.
        """
        q = ("SELECT oid, typname,"
             " typlen, typtype, typcategory, typdelim, typrelid"
             " FROM pg_type WHERE ")
        # accept Python 2 long OIDs as well as plain ints
        if isinstance(key, (int, long)):
            q += "oid = %d" % key
        else:
            q += "typname = '%s'" % self._escape_string(key)
        self._src.execute(q)
        rows = self._src.fetch(1)
        if not rows:
            # behave like an ordinary dict lookup instead of leaking
            # an IndexError from the empty result
            raise KeyError(key)
        res = list(rows[0])
        # make sure the oid, typlen and typrelid columns are ints
        res[0] = int(res[0])
        res[2] = int(res[2])
        res[6] = int(res[6])
        res = TypeInfo(*res)
        self[res.oid] = self[res.name] = res
        return res

    @staticmethod
    def typecast(typ, value):
        """Cast value according to database type.

        NULL values (None) are returned unchanged.  Types starting with
        an underscore are treated as arrays of the type without it.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        cast = _cast.get(typ)
        if cast is None:
            if typ.startswith('_'):
                # cast as an array type
                cast = _cast.get(typ[1:])
                return cast_array(value, cast)
            # no typecast available or necessary
            return value
        else:
            return cast(value)
209
210
# Characters that must be backslash-escaped inside a double-quoted
# array element (used by Cursor._quote_array).
_re_array_escape = regex(r'(["\\])')
# An array element needs double quotes if it contains braces, commas,
# double quotes, backslashes or whitespace, or if it would otherwise be
# read as NULL (matched case-insensitively).
_re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
213
214
class _quotedict(dict):
    """Dictionary whose item lookups pass the value through self.quote.

    Callers must assign a quoting function to the quote attribute before
    reading items (e.g. via %-formatting with this dict as the mapping).
    """

    def __getitem__(self, key):
        raw_value = super(_quotedict, self).__getitem__(key)
        return self.quote(raw_value)
223
224
225### Cursor Object
226
class Cursor(object):
    """Cursor object."""

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self._type_cache = dbcnx._type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns
        self.description = None
        # unofficial attributes for convenience and performance
        self.colnames = self.coltypes = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a subclass provides a static row factory, so none needs
            # to be built after executing a query
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object."""
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Returns an SQL literal for the value.  Numbers are returned
        unchanged so that the % operator formats them.  Raises
        InterfaceError for values of unsupported types.
        """
        if val is None:
            return 'NULL'
        if isinstance(val, (datetime, date, time, timedelta, Json)):
            # these are passed as their string representation
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            return "'%s'" % val
        if isinstance(val, float):
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            if isnan(val):
                return "'NaN'"
            return val
        if isinstance(val, (int, long, Decimal)):
            return val
        if isinstance(val, list):
            # array elements may contain single quotes, so the array
            # literal must be escaped like any other string literal
            return "'%s'" % self._cnx.escape_string(self._quote_array(val))
        if isinstance(val, tuple):
            q = self._quote
            return 'ROW(%s)' % ','.join(str(q(v)) for v in val)
        try:
            return val.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))

    def _quote_array(self, val):
        """Quote value as a literal constant for an array."""
        # We could also cast to an array constructor here, but that is more
        # verbose and you need to know the base type to build empty arrays.
        if isinstance(val, list):
            return '{%s}' % ','.join(self._quote_array(v) for v in val)
        if val is None:
            return 'null'
        # bool must be checked before the numeric types, since bool is a
        # subclass of int and would otherwise be rendered as True/False
        if isinstance(val, bool):
            return 't' if val else 'f'
        if isinstance(val, (int, long, float, Decimal)):
            return str(val)
        if isinstance(val, basestring):
            if not val:
                return '""'
            if _re_array_quote.search(val):
                return '"%s"' % _re_array_escape.sub(r'\\\1', val)
            return val
        try:
            return val.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))

    def _quoteparams(self, string, parameters):
        """Quote parameters.

        This function works for both mappings and sequences.
        """
        if isinstance(parameters, dict):
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def _make_description(self, info):
        """Make the description tuple for the given field info.

        Precision and scale are only provided for constrained numeric
        columns; values that cannot be determined are set to None.
        """
        name, typ, size, mod = info[1:]
        type_info = self._type_cache[typ]
        type_code = type_info.name
        if mod > 0:
            # PostgreSQL offsets varchar/numeric type modifiers by 4
            mod -= 4
        if type_code == 'numeric':
            if mod >= 0:
                # the modifier encodes precision and scale
                precision, scale = mod >> 16, mod & 0xffff
                size = precision
            else:
                # unconstrained numeric: no meaningful values available
                size = precision = scale = None
        else:
            if not size:
                size = type_info.len
            if size == -1:
                size = mod
            precision = scale = None
        return CursorDescription(name, type_code,
            None, size, precision, scale, None)

    def close(self):
        """Close the cursor object."""
        self._src.close()
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command)."""
        # The parameters may also be specified as list of tuples to e.g.
        # insert multiple rows in a single operation, but this kind of
        # usage is deprecated.  We make several plausibility checks because
        # tuples can also be passed with the meaning of ROW constructors.
        if (parameters and isinstance(parameters, list)
                and len(parameters) > 1
                and all(isinstance(p, tuple) for p in parameters)
                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
            return self.executemany(operation, parameters)
        else:
            # not a list of tuples
            return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        Starts a transaction first if none is open on the connection.
        Returns the cursor, or None if seq_of_parameters is empty.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                if parameters:
                    sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                if rows:  # number of rows affected by a DML statement
                    rowcount += rows
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result raw count and description
        if self._src.resulttype == RESULT_DQL:
            self.rowcount = self._src.ntuples
            description = self._make_description
            description = [description(info) for info in self._src.listinfo()]
            self.colnames = [d[0] for d in description]
            self.coltypes = [d[1] for d in description]
            self.description = description
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set."""
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        typecast = self._type_cache.typecast
        return [self.row_factory([typecast(typ, value)
            for typ, value in zip(self.coltypes, row)]) for row in result]

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = len(parameters) if parameters else 0
        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
        self.execute(query, parameters)
        return parameters

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.

        The size option sets the size of the buffer used when reading data
        from file-like objects.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            if size:
                raise ValueError("size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("The input must be %s" % type_name)
                if not binary_format:
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    else:
                        if not stream.endswith(b'\n'):
                            stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "Input stream must consist of %s" % type_name)
                        # text and csv rows must end with a newline, but
                        # binary data has to be passed through unchanged
                        if not binary_format:
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            else:
                                if not chunk.endswith(b'\n'):
                                    chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("Need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, int):
                raise TypeError("The size option must be an integer")
            if size > 0:

                def chunks():
                    while True:
                        buffer = read(size)
                        yield buffer
                        if not buffer or len(buffer) < size:
                            break

            else:

                def chunks():
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("Must specify a table, not a query")
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if format == 'binary':
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy to")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            if format == 'binary':
                decode = False
            else:
                # decode by default only where str is the unicode type
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("decode option must be a boolean")
            if decode and binary_format:
                raise ValueError("decode is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            # count the rows while yielding them; an int returned by
            # getdata ends the copy and is taken as the final row count
            self.rowcount = 0
            while True:
                row = getdata(decode)
                if isinstance(row, int):
                    if self.rowcount != row:
                        self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no output stream, return the generator
            return copy()

        # write the rows to the file-like output stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Note that since Python 2.6 the iterator protocol uses __next()__
    # instead of next(), we keep it only for backward compatibility of pgdb.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.
        Returns None if there are no column names.
        """
        colnames = self.colnames
        if colnames:
            try:
                try:
                    return namedtuple('Row', colnames, rename=True)._make
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    colnames = [v if v.isalnum() else 'column_%d' % n
                             for n, v in enumerate(colnames)]
                    return namedtuple('Row', colnames)._make
            except ValueError:  # there is still a problem with the field names
                colnames = ['column_%d' % n for n in range(len(colnames))]
                return namedtuple('Row', colnames)._make
778
779
# One entry of cursor.description, laid out as required by DB API 2.0.
CursorDescription = namedtuple('CursorDescription', (
    'name',  # column name
    'type_code',  # database type name of the column
    'display_size',  # not provided (None)
    'internal_size',  # size as derived from the type info and modifier
    'precision',  # only provided for numeric columns
    'scale',  # only provided for numeric columns
    'null_ok',  # not provided (None)
))
783
784
785### Connection Objects
786
class Connection(object):
    """DB-API 2.0 connection object wrapping a low-level connection."""

    # The DB-API exception classes are also made reachable as attributes
    # of every connection object.
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Wrap the low-level connection cnx and check that it is usable.

        Raises the error built by _op_error("invalid connection") if
        cnx.source() fails.
        """
        self._cnx = cnx  # low-level connection; reset to None by close()
        self._tnx = False  # commit()/rollback() only act while this is True
        self._type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # class instantiated by cursor()
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("invalid connection")

    def __enter__(self):
        """Start a runtime context, e.g. for running a transaction."""
        return self

    def __exit__(self, et, ev, tb):
        """Leave the runtime context: commit on success, roll back on error.

        The connection itself stays open; only the transaction is ended.
        """
        if any(arg is not None for arg in (et, ev, tb)):
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection, rolling back a pending transaction first.

        A DatabaseError during that rollback is ignored so that the
        connection still gets closed. Closing twice is an error.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # best effort: we are closing anyway
        self._cnx.close()
        self._cnx = None

    def _end_transaction(self, command, failure):
        """Send command (COMMIT/ROLLBACK) if a transaction is pending.

        The pending flag is cleared before the command is sent.
        DatabaseErrors propagate unchanged; any other error is reported
        as _op_error(failure).
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if not self._tnx:
            return
        self._tnx = False
        try:
            self._cnx.source().execute(command)
        except DatabaseError:
            raise
        except Exception:
            raise _op_error(failure)

    def commit(self):
        """Commit any pending transaction to the database."""
        self._end_transaction("COMMIT", "can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        self._end_transaction("ROLLBACK", "can't rollback")

    def cursor(self):
        """Return a new cursor (an instance of cursor_type) on this connection."""
        if not self._cnx:
            raise _op_error("connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cur = self.cursor()
            cur.execute(operation, params)
            return cur

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cur = self.cursor()
            cur.executemany(operation, param_seq)
            return cur
894
895
896### Module Interface
897
# Remember the previously bound ``connect`` (presumably the low-level
# connection function) before the DB-API ``connect`` below rebinds the name;
# the new function delegates to it.
_connect_ = connect

def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database.

    dsn has the form 'host:database:user:password:opt'. Fields missing
    at the end keep empty defaults, and a dsn without a usable split()
    (e.g. None) is ignored. The keyword arguments user, password and
    database override the DSN whenever they are not None. host overrides
    the DSN host and may carry a port as 'name:port'; a port that cannot
    be parsed as an integer is silently ignored. An empty host or user
    is passed on as None (empty host means localhost).

    Returns a Connection object.
    """
    dbport = -1
    dbhost = dbbase = dbuser = dbpasswd = dbopt = ""

    # Split the DSN and pad it so that all five fields are present.
    try:
        dsn_fields = dsn.split(":")
    except (AttributeError, TypeError):
        dsn_fields = []
    dbhost, dbbase, dbuser, dbpasswd, dbopt = (dsn_fields + [""] * 5)[:5]

    # Explicit keyword arguments take precedence over the DSN.
    dbuser = dbuser if user is None else user
    dbpasswd = dbpasswd if password is None else password
    dbbase = dbbase if database is None else database

    if host is not None:
        try:
            host_fields = host.split(":")
        except (AttributeError, TypeError):
            pass
        else:
            dbhost = host_fields[0]
            if len(host_fields) > 1:
                try:
                    dbport = int(host_fields[1])
                except (TypeError, ValueError):
                    pass  # unparsable port: keep the default

    # An empty host means localhost; an empty user is not passed either.
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    low_level = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(low_level)
945
946
947### Types Handling
948
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozenset of PostgreSQL type names. It compares equal to
    a string that is one of its names. A leading underscore (the prefix of
    PostgreSQL array type names) is ignored, so e.g. INTEGER also equals
    '_int4'. Comparisons with anything else fall back to frozenset.
    """

    def __new__(cls, values):
        """Create a type from an iterable of names or a blank-separated string."""
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):  # array type name: compare its base
                other = other[1:]
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):  # array type name: compare its base
                other = other[1:]
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Overriding __eq__ makes Python 3 set __hash__ to None, which would
    # silently turn this frozenset subclass unhashable (Python 2 keeps the
    # inherited hash). Restore frozenset hashing so that both versions
    # agree. Note that, as in Python 2, a Type hashes like the frozenset of
    # its names, not like the strings it compares equal to.
    __hash__ = frozenset.__hash__
976
977
class ArrayType:
    """Type class for PostgreSQL array types.

    An instance compares equal to every type name that starts with an
    underscore (PostgreSQL's naming convention for array types) and to
    any other ArrayType instance; it is unequal to everything else.
    """

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return isinstance(other, ArrayType)
        return other.startswith('_')

    def __ne__(self, other):
        if not isinstance(other, basestring):
            return not isinstance(other, ArrayType)
        return not other.startswith('_')
992
993
# Mandatory type objects defined by DB-API 2 specs.
# Each one compares equal to any of the PostgreSQL type names it lists.

STRING = Type(('char', 'bpchar', 'name', 'text', 'varchar'))
BINARY = Type(('bytea',))
NUMBER = Type(('int2', 'int4', 'serial', 'int8',
    'float4', 'float8', 'numeric', 'money'))
DATETIME = Type(('date', 'time', 'timetz', 'timestamp', 'timestamptz',
    'interval', 'abstime', 'reltime'))  # the last two are very old
ROWID = Type(('oid',))


# Additional, more specific type objects:

BOOL = Type(('bool',))
SMALLINT = Type(('int2',))
INTEGER = Type(('int2', 'int4', 'int8', 'serial'))
LONG = Type(('int8',))
FLOAT = Type(('float4', 'float8'))
NUMERIC = Type(('numeric',))
MONEY = Type(('money',))
DATE = Type(('date',))
TIME = Type(('time', 'timetz'))
TIMESTAMP = Type(('timestamp', 'timestamptz'))
INTERVAL = Type(('interval',))
JSON = Type(('json', 'jsonb'))

# Type object matching any array type name; the type objects above also
# compare equal to the '_'-prefixed array variants of their base types.

ARRAY = ArrayType()
1022
1023
1024# Mandatory type helpers defined by DB-API 2 specs:
1025
def Date(year, month, day):
    """Return a datetime.date for the given year, month and day."""
    return date(year=year, month=month, day=day)
1029
1030
def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Construct an object holding a time value.

    Returns a datetime.time. The DB-API 2 form Time(hour, minute, second)
    works unchanged. The optional tzinfo argument builds a time-zone aware
    value (e.g. for timetz); by default the result is naive, as before.
    """
    return time(hour, minute, second, microsecond, tzinfo)
1034
1035
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Construct an object holding a time stamp value.

    Returns a datetime.datetime. The optional tzinfo argument builds a
    time-zone aware value (e.g. for timestamptz); by default the result
    is naive, as before.
    """
    return datetime(year, month, day, hour, minute, second, microsecond,
        tzinfo)
1039
1040
def DateFromTicks(ticks):
    """Return the local date for ticks seconds since the epoch."""
    local = localtime(ticks)
    return Date(local.tm_year, local.tm_mon, local.tm_mday)
1044
1045
def TimeFromTicks(ticks):
    """Return the local time of day (whole seconds) for the given ticks."""
    local = localtime(ticks)
    return Time(local.tm_hour, local.tm_min, local.tm_sec)
1049
1050
def TimestampFromTicks(ticks):
    """Return the local date and time (whole seconds) for the given ticks."""
    local = localtime(ticks)
    return Timestamp(local.tm_year, local.tm_mon, local.tm_mday,
        local.tm_hour, local.tm_min, local.tm_sec)
1054
1055
class Binary(bytes):
    """Hold a binary (long) string value, as required by DB-API 2.

    Instances are ordinary byte strings of this subclass type.
    """
1058
1059
1060# Additional type helpers for PyGreSQL:
1061
class Json:
    """Wrap an object that should be passed to PostgreSQL as JSON.

    str() of the wrapper yields the JSON text. A string object is taken
    to be JSON text already and is returned unchanged; anything else is
    serialized with the encode callable (falls back to jsonencode when
    encode is not given or is falsy).
    """

    def __init__(self, obj, encode=None):
        self.obj = obj  # wrapped object, or ready-made JSON text
        self.encode = encode if encode else jsonencode

    def __str__(self):
        payload = self.obj
        if not isinstance(payload, basestring):
            payload = self.encode(payload)
        return payload

    # __pg_repr__ yields the same text as str() (presumably the hook the
    # parameter adaptation code looks for; confirm against its caller).
    __pg_repr__ = __str__
1076
1077
1078# If run as script, print some information:
1079
if __name__ == '__main__':
    # Run as a script: print the version line, a blank line, then the
    # module docstring (which contains the basic usage summary).
    # NOTE(review): the two-argument print() assumes print is a function
    # (Python 3 or a print_function import at file top) - confirm.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.