source: trunk/pgdb.py @ 786

Last change on this file since 786 was 786, checked in by cito, 4 years ago

Implement the cursor description as a property

It's better to create the description only on demand.

  • Property svn:keywords set to Id
File size: 35.4 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 786 2016-01-26 18:16:29Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary or a sequence)
32    # if they are passed. The binding syntax is the same as the %
33    # operator, and the parameter values are quoted automatically.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size and null_ok are not implemented, and
52    # precision and scale are only set for numeric columns.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from re import compile as regex
76from json import loads as jsondecode, dumps as jsonencode
77
# Compatibility shims so that the module runs on Python 2 and Python 3.

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# The abstract base classes live in collections.abc since Python 3.3,
# and the old aliases in collections were removed in Python 3.10.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args):
            # only fails when an ordered dict is actually requested
            raise NotSupportedError('OrderedDict is not supported')
102
103
### Module Constants

# the module complies with the DB API 2.0 specification
apilevel = '2.0'

# threads may share the module, but not connections
threadsafety = 1

# the module uses extended Python format codes, like %(name)s
paramstyle = 'pyformat'

# Shortcut methods (like Connection.execute) are not part of DB API 2
# and are not recommended by the DB SIG, but they can be handy; the
# Connection class only defines them if this flag is set.
shortcutmethods = 1
118
119
120### Internal Types Handling
121
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal values.

    Without an argument, the type currently in use is returned.  If a
    type is passed, it is used from now on for casting numeric columns
    and money values, and it is returned as well.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    return Decimal
128
129
130def _cast_bool(value):
131    return value[:1] in ('t', 'T')
132
133
134def _cast_money(value):
135    return Decimal(''.join(filter(
136        lambda v: v in '0123456789.-', value)))
137
138
# Map from PostgreSQL type names to functions that cast the text
# representation of a value of that type to a Python object.
# Types not listed here are passed through unchanged, and array types
# use the cast of their element type (see TypeCache.typecast).
_cast = dict(
    bool=_cast_bool,
    bytea=unescape_bytea,
    int2=int,
    int4=int,
    serial=int,
    int8=long,
    json=jsondecode,
    jsonb=jsondecode,
    oid=long,
    oid8=long,
    float4=float,
    float8=float,
    numeric=Decimal,
    money=_cast_money,
)
145
146
def _db_error(msg, cls=DatabaseError):
    """Create an exception of the given class (DatabaseError by default).

    The sqlstate attribute of the new exception is set to None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
152
153
def _op_error(msg):
    """Create an OperationalError with an empty sqlstate attribute."""
    return _db_error(msg, cls=OperationalError)
157
158
# Information about a database type, as read from the pg_type catalog
# by TypeCache (oid, typname, typlen, typtype, typcategory, typdelim,
# typrelid).
TypeInfo = namedtuple(
    'TypeInfo',
    ['oid', 'name', 'len', 'type', 'category', 'delim', 'relid'],
)
161
162
class TypeCache(dict):
    """Cache for database types.

    This cache maps type OIDs and type names to TypeInfo tuples containing
    the name and other important info for the database type.  Types that
    are not yet cached are looked up in the pg_type catalog on first access.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()

    def __missing__(self, key):
        """Look up a type by OID (int) or by name (string) in pg_type.

        The result is cached under both its OID and its name.
        Raises KeyError if there is no such type, as dicts usually do.
        """
        q = ("SELECT oid, typname,"
             " typlen, typtype, typcategory, typdelim, typrelid"
             " FROM pg_type WHERE ")
        if isinstance(key, int):
            q += "oid = %d" % key
        else:
            q += "typname = '%s'" % self._escape_string(key)
        self._src.execute(q)
        rows = self._src.fetch(1)
        if not rows:
            # unknown type: behave like a dict instead of an IndexError
            raise KeyError(key)
        res = list(rows[0])
        # oid, typlen and typrelid are returned as strings
        res[0] = int(res[0])
        res[2] = int(res[2])
        res[6] = int(res[6])
        res = TypeInfo(*res)
        self[res.oid] = self[res.name] = res
        return res

    @staticmethod
    def typecast(typ, value):
        """Cast value according to database type name typ.

        NULL values (None) are returned unchanged.  For array types (names
        starting with an underscore), the cast for the element type is
        applied to the array elements.  Values of types without a cast are
        returned unchanged.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        cast = _cast.get(typ)
        if cast is None:
            if typ.startswith('_'):
                # cast as an array type
                cast = _cast.get(typ[1:])
                return cast_array(value, cast)
            # no typecast available or necessary
            return value
        else:
            return cast(value)
209
210
211_re_array_escape = regex(r'(["\\])')
212_re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
213
214
215class _quotedict(dict):
216    """Dictionary with auto quoting of its items.
217
218    The quote attribute must be set to the desired quote function.
219    """
220
221    def __getitem__(self, key):
222        return self.quote(super(_quotedict, self).__getitem__(key))
223
224
225### Cursor Object
226
class Cursor(object):
    """Cursor object.

    A cursor executes database operations on a connection and gives access
    to their results.  Fetched rows are type cast according to the column
    types and then passed through row_factory.  By default, a factory that
    creates named tuples is built by build_row_factory() for each result.
    """

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self._type_cache = dbcnx._type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns:
        # None if there is no result set, True if there is one but its
        # description has not been built yet (see the description property)
        self._description = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a subclass defines a static row factory,
            # so do not replace it with a dynamically built one
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object.

        The cursor is closed; exceptions are not suppressed.
        """
        self.close()

    @staticmethod
    def _quote_ident(name):
        """Quote a name as an SQL identifier.

        Embedded double quotes are doubled so that the name cannot end
        the quoted identifier prematurely.
        """
        name = '%s' % (name,)
        return '"%s"' % name.replace('"', '""')

    def _quote(self, val):
        """Quote value depending on its type.

        Strings, binaries, dates, times, JSON, special floats, lists
        (arrays) and tuples (row constructors) are returned as SQL
        expressions.  Plain numbers are returned unchanged, so that they
        can be inserted with the % operator.
        """
        if val is None:
            return 'NULL'
        if isinstance(val, (datetime, date, time, timedelta, Json)):
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            return "'%s'" % val
        if isinstance(val, float):
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            if isnan(val):
                return "'NaN'"
            return val
        if isinstance(val, (int, long, Decimal)):
            return val
        if isinstance(val, list):
            # The array literal is embedded in an SQL string constant, so it
            # must be escaped like any other string value.  Otherwise single
            # quotes in the elements would end the constant (SQL injection).
            return "'%s'" % self._cnx.escape_string(self._quote_array(val))
        if isinstance(val, tuple):
            q = self._quote
            return 'ROW(%s)' % ','.join(str(q(v)) for v in val)
        try:
            return val.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))

    def _quote_array(self, val):
        """Quote value as a literal constant for an array.

        Nested lists become nested arrays.  String elements are double
        quoted and escaped when the array syntax requires it.
        """
        # We could also cast to an array constructor here, but that is more
        # verbose and you need to know the base type to build empty arrays.
        if isinstance(val, list):
            return '{%s}' % ','.join(self._quote_array(v) for v in val)
        if val is None:
            return 'null'
        # bool must be checked before int, since bool is a subclass of int
        if isinstance(val, bool):
            return 't' if val else 'f'
        if isinstance(val, (int, long, float, Decimal)):
            return str(val)
        if isinstance(val, basestring):
            if not val:
                return '""'
            if _re_array_quote.search(val):
                return '"%s"' % _re_array_escape.sub(r'\\\1', val)
            return val
        try:
            return val.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))

    def _quoteparams(self, string, parameters):
        """Quote parameters.

        This function works for both mappings and sequences.
        """
        if isinstance(parameters, dict):
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def _make_description(self, info):
        """Make the description tuple for the given field info.

        info is an entry of the source's listinfo() result; its last four
        items are unpacked as name, type OID, size and type modifier.
        """
        name, typ, size, mod = info[1:]
        type_info = self._type_cache[typ]
        type_code = type_info.name
        if mod > 0:
            mod -= 4
        if type_code == 'numeric' and mod >= 0:
            # the modifier encodes precision and scale of the column
            precision, scale = mod >> 16, mod & 0xffff
            size = precision
        else:
            # also used for unconstrained numeric columns (modifier -1),
            # which have no precision and scale
            if not size:
                size = type_info.len
            if size == -1:
                size = mod
            precision = scale = None
        return CursorDescription(name, type_code,
            None, size, precision, scale, None)

    @property
    def description(self):
        """Read-only attribute describing the result columns.

        This is None if there is no result set.  Otherwise it is a list
        of CursorDescription tuples, which is built on first access.
        """
        descr = self._description
        if descr is True:
            make = self._make_description
            descr = [make(info) for info in self._src.listinfo()]
            self._description = descr
        return descr

    @property
    def colnames(self):
        """Unofficial convenience method for getting the column names."""
        return [d[0] for d in self.description]

    @property
    def coltypes(self):
        """Unofficial convenience method for getting the column types."""
        return [d[1] for d in self.description]

    def close(self):
        """Close the cursor object."""
        self._src.close()
        self._description = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command).

        Returns the cursor itself, so that calls can be chained.
        """
        # The parameters may also be specified as list of tuples to e.g.
        # insert multiple rows in a single operation, but this kind of
        # usage is deprecated.  We make several plausibility checks because
        # tuples can also be passed with the meaning of ROW constructors.
        if (parameters and isinstance(parameters, list)
                and len(parameters) > 1
                and all(isinstance(p, tuple) for p in parameters)
                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
            return self.executemany(operation, parameters)
        else:
            # not a list of tuples
            return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        A transaction is started first if none is active.  Afterwards,
        rowcount and description refer to the result of the last operation
        if it was a query; otherwise rowcount is the total number of rows
        affected by all operations.  Returns the cursor itself, or None if
        the parameter sequence is empty.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self._description = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                if parameters:
                    sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                if rows:  # number of affected rows, if there were any
                    rowcount += rows
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result raw count and description
        if self._src.resulttype == RESULT_DQL:
            self._description = True  # fetch on demand
            self.rowcount = self._src.ntuples
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set (None if exhausted)."""
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.

        The values are type cast according to the column types.  The rows
        are passed through the row factory, or returned as lists if there
        is no row factory.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        if not result:
            return []
        typecast = self._type_cache.typecast
        # the column types are the same for all rows, so get them only once
        coltypes = self.coltypes
        rows = [[typecast(typ, value) for typ, value in zip(coltypes, row)]
                for row in result]
        row_factory = self.row_factory
        if row_factory is not None:
            # None here means there is no row factory, e.g. because
            # build_row_factory() found no columns to build one for
            rows = [row_factory(row) for row in rows]
        return rows

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = parameters and len(parameters) or 0
        query = 'select * from %s(%s)' % (
            self._quote_ident(procname), ','.join(n * ['%s']))
        self.execute(query, parameters)
        return parameters

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.

        The size option sets the size of the buffer used when reading data
        from file-like objects.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.

        Returns the cursor itself; rowcount is set to the number of rows
        copied as reported by the database.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            if size:
                raise ValueError("size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("The input must be %s" % type_name)
                if not binary_format:
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    else:
                        if not stream.endswith(b'\n'):
                            stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "Input stream must consist of %s" % type_name)
                        # binary data must be passed on unchanged,
                        # only text rows need to be terminated
                        if not binary_format:
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            else:
                                if not chunk.endswith(b'\n'):
                                    chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("Need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, int):
                raise TypeError("The size option must be an integer")
            if size > 0:

                def chunks():
                    while True:
                        buffer = read(size)
                        yield buffer
                        if not buffer or len(buffer) < size:
                            break

            else:

                def chunks():
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("Must specify a table, not a query")
        table = self._quote_ident(table)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if format == 'binary':
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join(self._quote_ident(col) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy to")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = self._quote_ident(table)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            if format == 'binary':
                decode = False
            else:
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("decode option must be a boolean")
            if decode and binary_format:
                raise ValueError("decode is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join(self._quote_ident(col) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            # getdata() returns the rows, and finally the row count (int)
            self.rowcount = 0
            while True:
                row = getdata(decode)
                if isinstance(row, int):
                    if self.rowcount != row:
                        self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no output stream, return the generator
            return copy()

        # write the rows to the file-like output stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Python 3 uses __next__() for the iterator protocol and Python 2 uses
    # next(); we keep next() also for backward compatibility of pgdb.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.

        Returns None if the result has no columns.
        """
        colnames = self.colnames
        if colnames:
            try:
                try:
                    return namedtuple('Row', colnames, rename=True)._make
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    colnames = [v if v.isalnum() else 'column_%d' % n
                             for n, v in enumerate(colnames)]
                    return namedtuple('Row', colnames)._make
            except ValueError:  # there is still a problem with the field names
                colnames = ['column_%d' % n for n in range(len(colnames))]
                return namedtuple('Row', colnames)._make
790
791
# One entry of Cursor.description, with the seven items
# required by the DB API 2 specification.
CursorDescription = namedtuple(
    'CursorDescription',
    ['name', 'type_code', 'display_size', 'internal_size',
     'precision', 'scale', 'null_ok'],
)
795
796
797### Connection Objects
798
class Connection(object):
    """Connection object.

    Wraps a low-level connection handle and remembers in ``_tnx`` whether
    a transaction is pending.  ``commit()`` and ``rollback()`` only send
    COMMIT/ROLLBACK while that flag is set; the flag is presumably set by
    the cursors when they start a transaction (not visible here).
    """

    # The DB-API exception classes, also reachable as attributes of
    # every connection object.
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        Raises the error built by ``_op_error("invalid connection")`` when
        the handle fails to create a source object.
        """
        self._cnx = cnx  # low-level connection handle
        self._tnx = False  # True while a transaction is pending
        self._type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # class instantiated by cursor()
        try:
            self._cnx.source()  # probe the handle once
        except Exception:
            raise _op_error("invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is committed when the block finished without an exception and
        rolled back otherwise.
        """
        block_failed = any(item is not None for item in (et, ev, tb))
        if block_failed:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back first; database errors from
        that rollback are ignored because the connection goes away anyway.
        Closing an already closed connection raises an error.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass
        self._cnx.close()
        self._cnx = None

    def _end_transaction(self, statement, failure):
        """Send ``statement`` if a transaction is pending.

        The pending flag is cleared before the statement is sent.
        DatabaseError propagates unchanged; any other exception is replaced
        by the error built from the ``failure`` message.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if not self._tnx:
            return
        self._tnx = False
        try:
            self._cnx.source().execute(statement)
        except DatabaseError:
            raise
        except Exception:
            raise _op_error(failure)

    def commit(self):
        """Commit any pending transaction to the database."""
        self._end_transaction("COMMIT", "can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        self._end_transaction("ROLLBACK", "can't rollback")

    def cursor(self):
        """Return a new cursor object using the connection."""
        if not self._cnx:
            raise _op_error("connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            implicit_cursor = self.cursor()
            implicit_cursor.execute(operation, params)
            return implicit_cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            implicit_cursor = self.cursor()
            implicit_cursor.executemany(operation, param_seq)
            return implicit_cursor
906
907
908### Module Interface
909
# Keep the previously bound ``connect`` under another name, because the
# ``def connect`` below rebinds the module-level name; the wrapper then
# opens the low-level connection via ``_connect_(...)``.  Presumably the
# original is the connect function of the _pg extension module (its
# import is not visible here).
_connect_ = connect
911
def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database.

    ``dsn`` may be a string of the form 'host:database:user:password:opt';
    any trailing parts may be left out.  Because the parts are separated
    by colons, a DSN part cannot itself contain a colon.

    The keyword arguments ``user``, ``password``, ``database`` and ``host``
    override the corresponding DSN parts when they are not None.  A port
    can only be given through ``host`` as 'host:port'; if it is missing or
    not an integer, the port stays -1 (presumably "use the default").
    An empty host (localhost) or empty user is passed on as None.

    Returns a Connection wrapping the newly opened low-level connection.
    """
    # Fill host, database, user, password and options from the DSN, in
    # that order, for as many colon-separated parts as it provides.
    settings = ["", "", "", "", ""]
    try:
        for position, part in enumerate(dsn.split(":")[:5]):
            settings[position] = part
    except (AttributeError, IndexError, TypeError):
        pass  # no usable DSN (e.g. None): keep the empty defaults
    dbhost, dbbase, dbuser, dbpasswd, dbopt = settings
    dbport = -1

    # Explicit keyword arguments take precedence over the DSN.
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            host_parts = host.split(":")
        except (AttributeError, TypeError):
            host_parts = ()  # not a string: ignore the host argument
        if host_parts:
            dbhost = host_parts[0]
        if len(host_parts) > 1:
            try:
                dbport = int(host_parts[1])
            except (TypeError, ValueError):
                pass  # unparsable port: keep the default

    # empty host is localhost; empty strings are passed on as None
    if dbhost == "":
        dbhost = None
    if dbuser == "":
        dbuser = None

    low_level = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(low_level)
957
958
959### Types Handling
960
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    An instance is a frozenset of PostgreSQL type names.  It compares
    equal to any type name string contained in the set.  A leading
    underscore, which marks array type names, is ignored, so an array
    type name also equals the type object of its element type.
    Comparisons with non-strings fall back to frozenset semantics.
    """

    def __new__(cls, values):
        # A string is taken as a whitespace-separated list of type names.
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Defining __eq__ implicitly sets __hash__ to None on Python 3, which
    # would make type objects unhashable there (unlike on Python 2).  Keep
    # the frozenset hash so they can be used as dict keys or set members.
    # Note that equality with type name strings is not reflected by it.
    __hash__ = frozenset.__hash__
988
989
class ArrayType:
    """Type class for PostgreSQL array types.

    Compares equal to any type name string starting with an underscore
    (the prefix of PostgreSQL array type names) and to any other
    ArrayType instance.
    """

    def __eq__(self, other):
        if isinstance(other, basestring):
            return other.startswith('_')
        else:
            return isinstance(other, ArrayType)

    def __ne__(self, other):
        if isinstance(other, basestring):
            return not other.startswith('_')
        else:
            return not isinstance(other, ArrayType)

    def __hash__(self):
        # All ArrayType instances compare equal, so they must share one
        # hash value; without this method, defining __eq__ would make the
        # instances unhashable on Python 3.  Equality with type name
        # strings cannot be reflected by the hash.
        return hash(ArrayType)
1004
1005
1006# Mandatory type objects defined by DB-API 2 specs:
1007
STRING = Type(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = Type(['bytea'])
NUMBER = Type(['int2', 'int4', 'serial', 'int8',
    'float4', 'float8', 'numeric', 'money'])
DATETIME = Type(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
    'interval', 'abstime', 'reltime'])  # abstime/reltime are very old
ROWID = Type(['oid'])


# Additional type objects (more specific):

BOOL = Type(['bool'])
SMALLINT = Type(['int2'])
INTEGER = Type(['int2', 'int4', 'int8', 'serial'])
LONG = Type(['int8'])
FLOAT = Type(['float4', 'float8'])
NUMERIC = Type(['numeric'])
MONEY = Type(['money'])
DATE = Type(['date'])
TIME = Type(['time', 'timetz'])
TIMESTAMP = Type(['timestamp', 'timestamptz'])
INTERVAL = Type(['interval'])
JSON = Type(['json', 'jsonb'])

# Type object matching any array type name (leading '_'); array type names
# also compare equal to the type objects of their element types:

ARRAY = ArrayType()
1034
1035
1036# Mandatory type helpers defined by DB-API 2 specs:
1037
def Date(year, month, day):
    """Construct an object holding a date value.

    The value is built with the module-level ``date`` class (presumably
    ``datetime.date``).
    """
    return date(year=year, month=month, day=day)
1041
1042
def Time(hour, minute=0, second=0, microsecond=0):
    """Construct an object holding a time value.

    Minutes, seconds and microseconds default to zero; the value is built
    with the module-level ``time`` class (presumably ``datetime.time``).
    """
    return time(hour=hour, minute=minute, second=second,
        microsecond=microsecond)
1046
1047
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Construct an object holding a time stamp value.

    The time-of-day parts default to zero; the value is built with the
    module-level ``datetime`` class (presumably ``datetime.datetime``).
    """
    return datetime(year=year, month=month, day=day,
        hour=hour, minute=minute, second=second, microsecond=microsecond)
1051
1052
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks (seconds since the epoch) are converted with ``localtime``,
    so the result is the calendar day in the local time zone.
    """
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
1056
1057
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    Hour, minute and second are taken from ``localtime(ticks)``; any
    fractional part of ``ticks`` is not carried over.
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
1061
1062
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    Date and time of day come from ``localtime(ticks)`` (local time zone);
    any fractional part of ``ticks`` is not carried over.
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
1066
1067
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is the DB-API 2 ``Binary`` constructor, implemented as a plain
    subclass of ``bytes``: it is constructed and compared like ``bytes``.
    """
1071
1072# Additional type helpers for PyGreSQL:
1073
class Json:
    """Construct a wrapper for holding an object serializable to JSON.

    The wrapped object is stored in ``obj``.  Its text form is produced by
    ``encode``, which defaults to the module-level ``jsonencode`` when no
    truthy encoder is passed.  String objects are returned unchanged,
    i.e. treated as already-encoded JSON.
    """

    def __init__(self, obj, encode=None):
        self.obj = obj
        self.encode = encode if encode else jsonencode

    def __str__(self):
        value = self.obj
        if not isinstance(value, basestring):
            value = self.encode(value)
        return value

    # Same text form under the name __pg_repr__ (presumably looked up by
    # the parameter adaptation code elsewhere in this module).
    __pg_repr__ = __str__
1088
1089
1090# If run as script, print some information:
1091
if __name__ == '__main__':
    # Print the module version followed by the module docstring (usage
    # notes).  ``version`` is defined elsewhere in this module, presumably
    # taken from the _pg extension.
    # NOTE(review): on Python 2 without
    # ``from __future__ import print_function`` the first call prints a
    # tuple -- confirm that import exists at the top of the file.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.