source: trunk/pgdb.py @ 787

Last change on this file since 787 was 787, checked in by cito, 4 years ago

Make the type cache of pgdb available to users

  • Property svn:keywords set to Id
File size: 35.4 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 787 2016-01-26 18:59:18Z cito $
8#
9
"""pgdb - DB-API 2.0 compliant module for PygreSQL.

(c) 1999, Pascal Andre <andre@via.ecp.fr>.
See package documentation for further information on copyright.

Inline documentation is sparse.
See DB-API 2.0 specification for usage information:
https://peps.python.org/pep-0249/

Basic usage:

    pgdb.connect(connect_string) # open a connection
    # connect_string = 'host:database:user:password:opt'
    # All parts are optional. You may also pass host through
    # password as keyword arguments. To pass a port,
    # pass it in the host keyword parameter:
    connection = pgdb.connect(host='localhost:5432')

    cursor = connection.cursor() # open a cursor

    cursor.execute(query[, params])
    # Execute a query, binding params (a mapping or a sequence) if
    # they are passed. The binding syntax is the same as for the
    # % operator; the parameter values are quoted automatically.

    cursor.executemany(query, list of params)
    # Execute a query many times, binding each param mapping
    # or sequence from the list.

    cursor.fetchone() # fetch one row, (value, value, ...)
    # By default, rows are returned as named tuples.

    cursor.fetchall() # fetch all rows, [(value, value, ...), ...]

    cursor.fetchmany([size])
    # returns size or cursor.arraysize number of rows,
    # [(value, value, ...), ...] from result set.
    # Default cursor.arraysize is 1.

    cursor.description # returns information about the columns
    #   [(column_name, type_name, display_size,
    #           internal_size, precision, scale, null_ok), ...]
    # Note that display_size and null_ok are not implemented,
    # and precision and scale are only set for numeric columns.

    cursor.rowcount # number of rows available in the result set,
    # or the number of rows affected by a DML statement.
    # Available after a call to execute.

    connection.commit() # commit transaction

    connection.rollback() # or rollback transaction

    cursor.close() # close the cursor

    connection.close() # close the connection
"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from re import compile as regex
76from json import loads as jsondecode, dumps as jsonencode
77
# Compatibility names for Python 2 and 3. Each name is probed separately;
# if the builtin does not exist (Python 3), a module-level substitute with
# the corresponding Python 3 meaning is defined instead.

try:
    long  # arbitrary-precision integer type of Python 2
except NameError:  # Python >= 3.0
    long = int  # int is unbounded in Python 3

try:
    unicode  # text string type of Python 2
except NameError:  # Python >= 3.0
    unicode = str  # str is the text type in Python 3

try:
    basestring  # common base of str and unicode in Python 2
except NameError:  # Python >= 3.0
    # no common base class exists; a tuple works with isinstance()
    basestring = (str, bytes)
92
# The abstract base classes live in collections.abc since Python 3.3;
# their aliases in the collections module were removed in Python 3.10.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args, **kwargs):
            """Placeholder used when no OrderedDict implementation exists.

            Raises NotSupportedError whenever it is called.
            """
            raise NotSupportedError('OrderedDict is not supported')
102
103
104### Module Constants
105
# Level of the DB API specification this module complies with.
apilevel = '2.0'

# Parameters are bound with Python's extended format codes,
# such as %s or %(name)s.
paramstyle = 'pyformat'

# Threads may share the module, but not connections.
threadsafety = 1

# Whether connections get execute() and executemany() shortcut methods.
# These were excluded from DB API 2 and are not recommended by the DB SIG,
# but they can be handy.
shortcutmethods = 1
118
119
120### Internal Types Handling
121
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal (numeric) values.

    When decimal_type is given, it becomes both the typecast function for
    the 'numeric' database type and the module-wide Decimal type.
    The type in effect after the call is returned.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    return Decimal
128
129
130def _cast_bool(value):
131    return value[:1] in ('t', 'T')
132
133
134def _cast_money(value):
135    return Decimal(''.join(filter(
136        lambda v: v in '0123456789.-', value)))
137
138
# Maps database type names to functions converting the text
# representation of a value into the corresponding Python value.
_cast = dict(
    # booleans and binary data
    bool=_cast_bool, bytea=unescape_bytea,
    # integers and object identifiers
    int2=int, int4=int, serial=int, int8=long,
    oid=long, oid8=long,
    # JSON documents
    json=jsondecode, jsonb=jsondecode,
    # floating point, exact numeric and monetary values
    float4=float, float8=float,
    numeric=Decimal, money=_cast_money)
145
146
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an error of class cls with message msg.

    The error gets an sqlstate attribute that is set to None, since no
    SQLSTATE code is associated with it.
    """
    err = cls(msg)
    err.sqlstate = None
    return err
152
153
def _op_error(msg):
    """Create (but do not raise) an OperationalError with message msg."""
    return _db_error(msg, cls=OperationalError)
157
158
# Information about a database type as read from the pg_type catalog.
TypeInfo = namedtuple(
    'TypeInfo', 'oid name len type category delim relid')
161
162
class TypeCache(dict):
    """Cache for database types.

    The cache maps both type OIDs (integers) and type names to TypeInfo
    tuples. Missing entries are looked up in the pg_type catalog on first
    access and then stored under both their OID and their name.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()

    def __missing__(self, key):
        """Look up an unknown type OID or name in the pg_type catalog.

        Raises KeyError if the database does not know the type.
        """
        if isinstance(key, int):
            condition = "oid = %d" % key
        else:
            condition = "typname = '%s'" % self._escape_string(key)
        self._src.execute(
            "SELECT oid, typname, typlen, typtype, typcategory, typdelim,"
            " typrelid FROM pg_type WHERE " + condition)
        rows = self._src.fetch(1)
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        oid, name, length, typtype, category, delim, relid = rows[0]
        info = TypeInfo(int(oid), name, int(length),
                        typtype, category, delim, int(relid))
        self[info.oid] = self[info.name] = info
        return info

    @staticmethod
    def typecast(typ, value):
        """Cast a value according to the name of its database type.

        NULL values (None) are passed through. Type names starting with
        an underscore denote array types; their elements are cast with the
        function registered for the base type. Values of types without a
        registered cast are returned unchanged.
        """
        if value is None:
            return None
        cast = _cast.get(typ)
        if cast is not None:
            return cast(value)
        if typ.startswith('_'):
            return cast_array(value, _cast.get(typ[1:]))
        return value
212
213
# Matches the characters that need a backslash escape inside a
# double-quoted array element: double quotes and backslashes.
_re_array_escape = regex(r'(["\\])')
# Matches array element text that must be double-quoted: text containing
# braces, commas, double quotes, backslashes or whitespace, or text that is
# exactly NULL (in any letter case), which unquoted would denote a null
# element.
_re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
216
217
218class _quotedict(dict):
219    """Dictionary with auto quoting of its items.
220
221    The quote attribute must be set to the desired quote function.
222    """
223
224    def __getitem__(self, key):
225        return self.quote(super(_quotedict, self).__getitem__(key))
226
227
228### Cursor Object
229
class Cursor(object):
    """Cursor object.

    A cursor runs operations on the connection it was created from and
    fetches their results, converting column values to Python types with
    the type cache of that connection.
    """

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self.type_cache = dbcnx.type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns
        # (None: no result, True: build on demand, otherwise a list)
        self._description = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a custom static row factory is used, never rebuild it
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object."""
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Returns an SQL literal as a string, or the value itself for
        finite numbers, which are then rendered by the % operator.
        Raises InterfaceError for unsupported types that do not provide
        a __pg_repr__() method.
        """
        if val is None:
            return 'NULL'
        if isinstance(val, (datetime, date, time, timedelta, Json)):
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            return "'%s'" % val
        if isinstance(val, float):
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            if isnan(val):
                return "'NaN'"
            return val
        if isinstance(val, (int, long, Decimal)):
            return val
        if isinstance(val, list):
            return "'%s'" % self._quote_array(val)
        if isinstance(val, tuple):
            q = self._quote
            return 'ROW(%s)' % ','.join(str(q(v)) for v in val)
        try:
            return val.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))

    def _quote_array(self, val):
        """Quote value as a literal constant for an array.

        Nested lists become nested array literals, None becomes a null
        element. Raises InterfaceError for unsupported element types that
        do not provide a __pg_repr__() method.
        """
        # We could also cast to an array constructor here, but that is more
        # verbose and you need to know the base type to build empty arrays.
        if isinstance(val, list):
            return '{%s}' % ','.join(self._quote_array(v) for v in val)
        if val is None:
            return 'null'
        # bool must be checked before int: bool is a subclass of int, so
        # otherwise booleans would be rendered as 'True' or 'False'
        if isinstance(val, bool):
            return 't' if val else 'f'
        if isinstance(val, (int, long, float)):
            return str(val)
        if isinstance(val, basestring):
            if not val:
                return '""'
            if _re_array_quote.search(val):
                return '"%s"' % _re_array_escape.sub(r'\\\1', val)
            return val
        try:
            return val.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))

    def _quoteparams(self, string, parameters):
        """Quote parameters.

        This function works for both mappings and sequences.
        """
        if isinstance(parameters, dict):
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def _make_description(self, info):
        """Make the description tuple for the given field info.

        The info is one entry of the source's listinfo(); its items after
        the first are used as name, type (a type cache key, presumably the
        type OID), size and type modifier.
        """
        name, typ, size, mod = info[1:]
        type_info = self.type_cache[typ]
        type_code = type_info.name
        if mod > 0:
            # remove the 4-byte header included in PostgreSQL type modifiers
            mod -= 4
        if type_code == 'numeric':
            # the modifier of numeric types packs (precision << 16) | scale
            precision, scale = mod >> 16, mod & 0xffff
            size = precision
        else:
            if not size:
                # fall back to the type length from the catalog
                # (the TypeInfo field is called 'len', there is no 'size')
                size = type_info.len
            if size == -1:
                size = mod
            precision = scale = None
        return CursorDescription(name, type_code,
            None, size, precision, scale, None)

    @property
    def description(self):
        """Read-only attribute describing the result columns."""
        descr = self._description
        if self._description is True:
            make = self._make_description
            descr = [make(info) for info in self._src.listinfo()]
            self._description = descr
        return descr

    @property
    def colnames(self):
        """Unofficial convenience method for getting the column names.

        Returns None if there is no result description.
        """
        description = self.description
        if description is None:
            return None
        return [d[0] for d in description]

    @property
    def coltypes(self):
        """Unofficial convenience method for getting the column types.

        Returns None if there is no result description.
        """
        description = self.description
        if description is None:
            return None
        return [d[1] for d in description]

    def close(self):
        """Close the cursor object."""
        self._src.close()
        self._description = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command)."""
        # The parameters may also be specified as list of tuples to e.g.
        # insert multiple rows in a single operation, but this kind of
        # usage is deprecated.  We make several plausibility checks because
        # tuples can also be passed with the meaning of ROW constructors.
        if (parameters and isinstance(parameters, list)
                and len(parameters) > 1
                and all(isinstance(p, tuple) for p in parameters)
                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
            return self.executemany(operation, parameters)
        else:
            # not a list of tuples
            return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        A transaction is started first if none is active on the connection.
        Returns the cursor itself, or None if no parameters were given.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self._description = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                if parameters:
                    sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                if rows:  # true if not DML
                    rowcount += rows
                else:
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result raw count and description
        if self._src.resulttype == RESULT_DQL:
            self._description = True  # fetch on demand
            self.rowcount = self._src.ntuples
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set, or None if exhausted."""
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        typecast = self.type_cache.typecast
        # determine column types and row factory once, not once per row
        coltypes = self.coltypes
        row_factory = self.row_factory
        return [row_factory([typecast(typ, value)
            for typ, value in zip(coltypes, row)]) for row in result]

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = len(parameters) if parameters else 0
        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
        self.execute(query, parameters)
        return parameters

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.

        The size option sets the size of the buffer used when reading data
        from file-like objects.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            if size:
                raise ValueError("size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("The input must be %s" % type_name)
                if not binary_format:
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    else:
                        if not stream.endswith(b'\n'):
                            stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "Input stream must consist of %s" % type_name)
                        # rows of text or csv input must be terminated by
                        # a newline, but binary data must pass unchanged
                        if not binary_format:
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            elif not chunk.endswith(b'\n'):
                                chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("Need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, (int, long)):
                raise TypeError("The size option must be an integer")
            if size > 0:

                def chunks():
                    while True:
                        buffer = read(size)
                        yield buffer
                        if not buffer or len(buffer) < size:
                            break

            else:

                def chunks():
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("Must specify a table, not a query")
        table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy to")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            if format == 'binary':
                decode = False
            else:
                # decode by default only where str is the text type
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("decode option must be a boolean")
            if decode and binary_format:
                raise ValueError("decode is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            self.rowcount = 0
            while True:
                row = getdata(decode)
                if isinstance(row, int):
                    # an int result ends the data (presumably the number
                    # of copied rows) and is taken as the row count
                    if self.rowcount != row:
                        self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no input stream, return the generator
            return copy()

        # write the rows to the file-like input stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Note that since Python 2.6 the iterator protocol uses __next()__
    # instead of next(), we keep it only for backward compatibility of pgdb.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.
        Returns None if there are no column names.
        """
        colnames = self.colnames
        if colnames:
            try:
                try:
                    return namedtuple('Row', colnames, rename=True)._make
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    colnames = [v if v.isalnum() else 'column_%d' % n
                             for n, v in enumerate(colnames)]
                    return namedtuple('Row', colnames)._make
            except ValueError:  # there is still a problem with the field names
                colnames = ['column_%d' % n for n in range(len(colnames))]
                return namedtuple('Row', colnames)._make
793
794
# One entry of Cursor.description, with the fields defined by DB API 2.0.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
798
799
800### Connection Objects
801
class Connection(object):
    """Connection object.

    Wraps a low-level connection object and provides the DB-API 2
    connection methods close(), commit(), rollback() and cursor().
    It can also be used as a context manager that ends the current
    transaction when the block is left.
    """

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object wrapping cnx.

        Raises the error built by _op_error("invalid connection") if
        calling cnx.source() fails.
        """
        self._cnx = cnx  # the wrapped low-level connection
        # transaction state; only reset to False in this class, presumably
        # set to True elsewhere (e.g. by cursors) -- TODO confirm
        self._tnx = False
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # class instantiated by cursor()
        try:
            cnx.source()  # probe the connection
        except Exception:
            raise _op_error("invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions;
        the connection itself is returned as the context value.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        Commits when the block finished without an exception, rolls back
        otherwise.  The connection is not closed and an exception raised
        in the block is not suppressed.
        """
        clean_exit = et is None and ev is None and tb is None
        end_transaction = self.commit if clean_exit else self.rollback
        end_transaction()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back first; a DatabaseError from
        that rollback is ignored.  Closing an already closed connection
        raises _op_error("connection has been closed").
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # best effort only, the connection is closed anyway
        self._cnx.close()
        self._cnx = None

    def commit(self):
        """Commit any pending transaction to the database.

        Does nothing when no transaction is pending.  A DatabaseError is
        re-raised unchanged; any other failure is reported as
        _op_error("can't commit").
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if not self._tnx:
            return
        # the flag is cleared before COMMIT runs, so it is False even if
        # the COMMIT fails
        self._tnx = False
        try:
            self._cnx.source().execute("COMMIT")
        except DatabaseError:
            raise
        except Exception:
            raise _op_error("can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction.

        Does nothing when no transaction is pending.  A DatabaseError is
        re-raised unchanged; any other failure is reported as
        _op_error("can't rollback").
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if not self._tnx:
            return
        # the flag is cleared before ROLLBACK runs, so it is False even if
        # the ROLLBACK fails
        self._tnx = False
        try:
            self._cnx.source().execute("ROLLBACK")
        except DatabaseError:
            raise
        except Exception:
            raise _op_error("can't rollback")

    def cursor(self):
        """Return a new cursor object using the connection.

        The cursor is created by calling self.cursor_type with this
        connection; any failure doing so is reported as
        _op_error("invalid connection").
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor.

            Returns the new cursor the operation was executed on.
            """
            cur = self.cursor()
            cur.execute(operation, params)
            return cur

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence.

            Returns the new cursor the operations were executed on.
            """
            cur = self.cursor()
            cur.executemany(operation, param_seq)
            return cur
909
910
911### Module Interface
912
# Keep a reference to the previously bound connect (presumably the
# low-level one from the C extension -- confirm) before the DB-API
# connect() below shadows that name.
_connect_ = connect


def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database.

    dsn is an optional string 'host:database:user:password:opt' in which
    trailing parts may be left out.  The user, password, host and database
    keyword arguments override the DSN values when they are not None;
    host may also carry a port as 'host:port'.  An empty host or user is
    passed on as None.  Returns a Connection wrapping the new connection.
    """
    # split the DSN into (up to) five parts; missing parts stay empty
    dsn_parts = ["", "", "", "", ""]
    try:
        for index, part in enumerate(dsn.split(":")[:5]):
            dsn_parts[index] = part
    except (AttributeError, IndexError, TypeError):
        pass  # no usable DSN (e.g. None): keep the defaults
    dbhost, dbbase, dbuser, dbpasswd, dbopt = dsn_parts
    dbport = -1  # stays -1 unless a port is given in host

    # explicit keyword arguments take precedence over the DSN
    dbuser = dbuser if user is None else user
    dbpasswd = dbpasswd if password is None else password
    dbbase = dbbase if database is None else database
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # the port is optional; keep whatever could be parsed

    # empty host is localhost
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    # open the connection
    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
959    return Connection(cnx)
960
961
962### Types Handling
963
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozenset of type names.  It can be built from an iterable
    of names or from a single whitespace-separated string of names.
    Compared with a string, it is equal if the string is one of its names.
    A leading underscore is ignored in that comparison, so the array type
    name '_int4' equals a Type containing 'int4'.  Compared with anything
    else, it behaves like a plain frozenset.
    """

    def __new__(cls, values):
        # accept a whitespace-separated string of type names as well
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]  # array type: compare its base type
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]  # array type: compare its base type
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Overriding __eq__ implicitly sets __hash__ to None on Python 3,
    # which made Type objects unhashable (unusable as dict keys or set
    # members) unlike their frozenset base.  Restore the frozenset hash;
    # it agrees with __eq__ for comparisons between sets.
    __hash__ = frozenset.__hash__
991
992
class ArrayType:
    """Type class for PostgreSQL array types.

    Compares equal to any type name starting with an underscore (array
    type names such as '_int4') and to any other ArrayType instance.
    """

    def __eq__(self, other):
        if isinstance(other, basestring):
            return other.startswith('_')
        else:
            return isinstance(other, ArrayType)

    def __ne__(self, other):
        if isinstance(other, basestring):
            return not other.startswith('_')
        else:
            return not isinstance(other, ArrayType)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable.  All ArrayType
        # instances compare equal to each other, so they share one hash.
        return hash(ArrayType)
1007
1008
# Mandatory type objects defined by DB-API 2 specs
# (each one compares equal to every PostgreSQL type name it lists):

STRING = Type(('char', 'bpchar', 'name', 'text', 'varchar'))
BINARY = Type(('bytea',))
NUMBER = Type(('int2', 'int4', 'serial', 'int8',
               'float4', 'float8', 'numeric', 'money'))
DATETIME = Type(('date', 'time', 'timetz', 'timestamp', 'timestamptz',
                 'interval',
                 'abstime', 'reltime'))  # these last two are very old
ROWID = Type(('oid',))


# Additional type objects (more specific):

BOOL = Type(('bool',))
SMALLINT = Type(('int2',))
INTEGER = Type(('int2', 'int4', 'int8', 'serial'))
LONG = Type(('int8',))
FLOAT = Type(('float4', 'float8'))
NUMERIC = Type(('numeric',))
MONEY = Type(('money',))
DATE = Type(('date',))
TIME = Type(('time', 'timetz'))
TIMESTAMP = Type(('timestamp', 'timestamptz'))
INTERVAL = Type(('interval',))
JSON = Type(('json', 'jsonb'))

# Type object for arrays: it equals any array type name ('_' prefix),
# while the Type objects above also equal the array names of their types:

ARRAY = ArrayType()
1037
1038
1039# Mandatory type helpers defined by DB-API 2 specs:
1040
def Date(year, month, day):
    """Construct an object holding a date value.

    Returns a ``date`` for the given year, month and day.
    """
    return date(year=year, month=month, day=day)
1044
1045
def Time(hour, minute=0, second=0, microsecond=0):
    """Construct an object holding a time value.

    Returns a ``time``; omitted minute, second and microsecond are 0.
    """
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond)
1049
1050
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Construct an object holding a time stamp value.

    Returns a ``datetime``; omitted time fields default to 0 (midnight).
    """
    return datetime(year=year, month=month, day=day,
                    hour=hour, minute=minute, second=second,
                    microsecond=microsecond)
1054
1055
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    ticks are seconds since the epoch, interpreted in local time.
    """
    local = localtime(ticks)
    return date(local.tm_year, local.tm_mon, local.tm_mday)
1059
1060
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    ticks are seconds since the epoch, interpreted in local time;
    fractions of a second are dropped.
    """
    local = localtime(ticks)
    return time(local.tm_hour, local.tm_min, local.tm_sec)
1064
1065
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    ticks are seconds since the epoch, interpreted in local time;
    fractions of a second are dropped.
    """
    local = localtime(ticks)
    return datetime(local.tm_year, local.tm_mon, local.tm_mday,
                    local.tm_hour, local.tm_min, local.tm_sec)
1069
1070
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    Binary is a plain subclass of bytes and adds no behavior of its own.
    """
1073
1074
1075# Additional type helpers for PyGreSQL:
1076
class Json:
    """Construct a wrapper for holding an object serializable to JSON.

    obj is the wrapped value.  encode is the function used to serialize it;
    jsonencode is used when encode is not given (or is falsy).  A wrapped
    value that is already a string is returned unchanged by str().
    """

    def __init__(self, obj, encode=None):
        self.obj = obj
        if not encode:
            encode = jsonencode
        self.encode = encode

    def __str__(self):
        value = self.obj
        return value if isinstance(value, basestring) else self.encode(value)

    # the same text is used as the representation passed to PostgreSQL
    __pg_repr__ = __str__
1091
1092
# If run as script, print some information:

if __name__ == '__main__':
    # Print the module's version (the name `version` is bound elsewhere in
    # this module), a blank line, and then the module docstring with its
    # usage overview.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.