source: trunk/pgdb.py @ 781

Last change on this file since 781 was 781, checked in by cito, 4 years ago

Add full support for PostgreSQL array types

At the core of this patch is a fast parser for the peculiar syntax of
literal array expressions in PostgreSQL that was added to the C module.
This is not trivial, because PostgreSQL arrays can be multidimensional
and the syntax is different from Python and SQL expressions.

The Python pg and pgdb modules make use of this parser so that they can
return database columns containing PostgreSQL arrays to Python as lists.
Also added quoting methods that allow passing PostgreSQL arrays as lists
to insert()/update() and execute/executemany(). These methods are simpler
and were implemented in Python but needed support from the regex module.

The patch also makes getresult() in pg automatically return bytea
values in unescaped form as bytes strings. Before, it was necessary to
call unescape_bytea manually. The pgdb module did this already.

The patch includes some more refactorings and simplifications regarding
the quoting and casting in pg and pgdb.

Some references to antique PostgreSQL types that are not used any more
in the supported PostgreSQL versions have been removed.

Also added documentation and tests for the new features.

  • Property svn:keywords set to Id
File size: 34.2 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 781 2016-01-25 20:44:52Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is the same as for the
    # % operator; the values are quoted according to their types.
34
    cursor.executemany(query, list of params)
    # Execute a query many times, binding each set of params
    # (a dictionary or a sequence) from the list.
38
    cursor.fetchone() # fetch one row, (value, value, ...)

    cursor.fetchall() # fetch all rows, [(value, value, ...), ...]

    cursor.fetchmany([size])
    # returns size or cursor.arraysize number of rows,
    # [(value, value, ...), ...] from result set.
    # Default cursor.arraysize is 1.
    # By default, rows are returned as named tuples.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
    cursor.rowcount # number of rows in the result set, or the
    # number of rows affected by a DML statement.
    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
from __future__ import print_function

from _pg import *

from datetime import date, time, datetime, timedelta
from time import localtime
from decimal import Decimal
from math import isnan, isinf
from collections import namedtuple
from re import compile as regex
from json import loads as jsondecode, dumps as jsonencode

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3 (the alias is gone in Python >= 3.10)
    from collections import Iterable
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args):
            raise NotSupportedError('OrderedDict is not supported')
102
103
### Module Constants

# compliant with DB API 2.0
apilevel = '2.0'

# threads may share the module, but not connections
threadsafety = 1

# this module uses extended Python format codes (e.g. %(name)s)
paramstyle = 'pyformat'

# shortcut methods have been excluded from DB API 2 and
# are not recommended by the DB SIG, but they can be handy
shortcutmethods = 1
118
119
120### Internal Types Handling
121
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal values.

    When a type is passed, it becomes the typecast for numeric columns
    and replaces the module-level Decimal name. The current type is
    returned in any case.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    return Decimal
128
129
130def _cast_bool(value):
131    return value[:1] in ('t', 'T')
132
133
134def _cast_money(value):
135    return Decimal(''.join(filter(
136        lambda v: v in '0123456789.-', value)))
137
138
# Typecast functions keyed by database type name. Values of types
# without an entry are returned unchanged by TypeCache.typecast();
# array types (names with a leading underscore) use their base type.
_cast = {
    'bool': _cast_bool,
    'bytea': unescape_bytea,
    'int2': int,
    'int4': int,
    'serial': int,
    'int8': long,
    'json': jsondecode,
    'jsonb': jsondecode,
    'oid': long,
    'oid8': long,
    'float4': float,
    'float8': float,
    'numeric': Decimal,
    'money': _cast_money,
}
145
146
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls with the given message.

    The sqlstate attribute of the exception is set to None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
152
153
def _op_error(msg):
    """Create an OperationalError with an empty sqlstate attribute."""
    return _db_error(msg, cls=OperationalError)
157
158
class TypeCache(dict):
    """Cache for database types.

    Maps type oids to partial column descriptions which are looked up
    in the pg_type system catalog when an oid is requested first.
    """

    def __init__(self, cnx):
        """Initialize the type cache for the given low-level connection."""
        super(TypeCache, self).__init__()
        self._src = cnx.source()

    @staticmethod
    def typecast(typ, value):
        """Cast a value according to the name of its database type.

        Type names starting with an underscore are treated as array
        types whose elements are cast with the cast of the base type.
        Values of types without a cast are returned unchanged.
        """
        if value is None:
            return None  # NULL values never need a typecast
        cast = _cast.get(typ)
        if cast is not None:
            return cast(value)
        if typ.startswith('_'):
            # array type (the cast of the base type may also be None)
            return cast_array(value, _cast.get(typ[1:]))
        return value  # no typecast available or necessary

    def getdescr(self, oid):
        """Get the description of the database type with the given oid.

        Returns a tuple (typname, None, typlen, None, None, None).
        The column name is not part of it and must be prepended by
        the caller. The result is cached, so pg_type is queried
        only once per oid.
        """
        if oid not in self:
            self._src.execute(
                "SELECT typname, typlen "
                "FROM pg_type WHERE oid=%s" % oid)
            row = self._src.fetch(1)[0]
            self[oid] = (row[0], None, int(row[1]), None, None, None)
        return self[oid]
198
199
200_re_array_escape = regex(r'(["\\])')
201_re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
202
203class _quotedict(dict):
204    """Dictionary with auto quoting of its items.
205
206    The quote attribute must be set to the desired quote function.
207    """
208
209    def __getitem__(self, key):
210        return self.quote(super(_quotedict, self).__getitem__(key))
211
212
213### Cursor Object
214
class Cursor(object):
    """Cursor object.

    Cursors are created by Connection.cursor() and run their queries
    through a source object of the low-level connection.
    """

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection.

        The dbcnx argument is the Connection object the cursor belongs
        to; a new source object is created on its low-level connection.
        """
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self._type_cache = dbcnx._type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns
        self.description = None
        # unofficial attributes for convenience and performance
        self.colnames = self.coltypes = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a subclass provides a static row factory,
            # so there is no need to build one after each query
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object.

        The cursor is closed, whether an exception occurred or not.
        """
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Returns a string with an SQL literal or expression, or the
        value itself for integers, decimals and finite floats.

        Raises InterfaceError if the type of the value is not supported.
        """
        if val is None:
            return 'NULL'
        if isinstance(val, (datetime, date, time, timedelta, Json)):
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            return "'%s'" % val
        if isinstance(val, float):
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            if isnan(val):
                return "'NaN'"
            return val
        if isinstance(val, (int, long, Decimal)):
            return val
        if isinstance(val, list):
            # The array literal can contain arbitrary text (e.g. apostrophes
            # in string elements), so it must be escaped like any other
            # string before it is embedded in a single-quoted SQL literal.
            return "'%s'" % self._cnx.escape_string(self._quote_array(val))
        if isinstance(val, tuple):
            q = self._quote
            return 'ROW(%s)' % ','.join(str(q(v)) for v in val)
        try:
            return val.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))

    def _quote_array(self, val):
        """Quote value as a literal constant for an array.

        Nested lists become nested arrays. The result is the text of the
        array literal; it still needs to be escaped as an SQL string.

        Raises InterfaceError if the type of an element is not supported.
        """
        # We could also cast to an array constructor here, but that is more
        # verbose and you need to know the base type to build empty arrays.
        if isinstance(val, list):
            return '{%s}' % ','.join(self._quote_array(v) for v in val)
        if val is None:
            return 'null'
        # bool must be checked before int, since it is a subclass of int
        if isinstance(val, bool):
            return 't' if val else 'f'
        if isinstance(val, (int, long, float, Decimal)):
            return str(val)
        if isinstance(val, (datetime, date, time, timedelta, Json)):
            # use the same textual representation as _quote()
            val = str(val)
        if isinstance(val, basestring):
            if not val:
                return '""'
            if _re_array_quote.search(val):
                return '"%s"' % _re_array_escape.sub(r'\\\1', val)
            return val
        try:
            return val.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))

    def _quoteparams(self, string, parameters):
        """Quote parameters and merge them into the operation string.

        This function works for both mappings and sequences. Values of
        mappings are quoted when they are looked up by the % operator.
        """
        if isinstance(parameters, dict):
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def close(self):
        """Close the cursor object and reset its result attributes."""
        self._src.close()
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command).

        The parameters can be a mapping or a sequence; they are quoted
        and bound to the placeholders in the operation string.
        Returns the cursor itself.
        """
        # The parameters may also be specified as list of tuples to e.g.
        # insert multiple rows in a single operation, but this kind of
        # usage is deprecated.  We make several plausibility checks because
        # tuples can also be passed with the meaning of ROW constructors.
        if (parameters and isinstance(parameters, list)
                and len(parameters) > 1
                and all(isinstance(p, tuple) for p in parameters)
                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
            return self.executemany(operation, parameters)
        # not a list of tuples
        return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        A transaction is started first if the connection is not already
        in one. Returns the cursor itself, or None if the sequence of
        parameters is empty (then nothing is executed).
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                if parameters:
                    sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                if rows:  # accumulate the number of affected rows (DML)
                    rowcount += rows
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result row count and description
        if self._src.resulttype == RESULT_DQL:
            self.rowcount = self._src.ntuples
            getdescr = self._type_cache.getdescr
            description = [CursorDescription(
                info[1], *getdescr(info[2])) for info in self._src.listinfo()]
            self.colnames = [info[0] for info in description]
            self.coltypes = [info[1] for info in description]
            self.description = description
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set, or None if exhausted."""
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.

        Every value is cast according to its column type, and every
        row is passed through the row factory.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        typecast = self._type_cache.typecast
        return [self.row_factory([typecast(typ, value)
            for typ, value in zip(self.coltypes, row)]) for row in result]

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = len(parameters) if parameters else 0
        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
        self.execute(query, parameters)
        return parameters

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.

        The size option sets the size of the buffer used when reading data
        from file-like objects.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.

        Returns the cursor itself; rowcount is set to the result of the copy.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            if size:
                raise ValueError("size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("The input must be %s" % type_name)
                if not binary_format:
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    elif not stream.endswith(b'\n'):
                        stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "Input stream must consist of %s" % type_name)
                        # rows in the text formats must end with a newline,
                        # but binary data must be passed on unaltered
                        if not binary_format:
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            elif not chunk.endswith(b'\n'):
                                chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("Need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, int):
                raise TypeError("The size option must be an integer")
            if size > 0:

                def chunks():
                    while True:
                        data = read(size)
                        yield data
                        if not data or len(data) < size:
                            break

            else:

                def chunks():
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("Must specify a table, not a query")
        table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy to")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            # by default, decode text output only if str is unicode
            decode = False if binary_format else str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("decode option must be a boolean")
            if decode and binary_format:
                raise ValueError("decode is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            self.rowcount = 0
            while True:
                row = getdata(decode)
                if isinstance(row, int):
                    # an integer marks the end and is taken as the row count
                    self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no output stream, return the generator
            return copy()

        # write the rows to the file-like output stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Python 3 uses __next__() in the iterator protocol, while Python 2
    # uses next(), so we provide both names.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.

        Returns None if there are no column names.
        """
        colnames = self.colnames
        if colnames:
            try:
                try:
                    return namedtuple('Row', colnames, rename=True)._make
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    colnames = [v if v.isalnum() else 'column_%d' % n
                             for n, v in enumerate(colnames)]
                    return namedtuple('Row', colnames)._make
            except ValueError:  # there is still a problem with the field names
                colnames = ['column_%d' % n for n in range(len(colnames))]
                return namedtuple('Row', colnames)._make
748
749
# description of a result column, one item of Cursor.description
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
753
754
755### Connection Objects
756
class Connection(object):
    """Connection object.

    Wraps a low-level connection. A pending transaction is committed by
    commit() and rolled back by rollback() or close().
    """

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        The low-level connection must be able to create a query source;
        otherwise the error created by _op_error() is raised.
        """
        self._cnx = cnx  # low-level connection, set to None by close()
        self._tnx = False  # transaction state, checked before COMMIT/ROLLBACK
        self._type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # class instantiated by cursor()
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions;
        the connection itself is bound by the with statement.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is rolled back if the block raised and committed otherwise.
        Exceptions raised in the block are not suppressed.
        """
        if et is not None or ev is not None or tb is not None:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back first; a DatabaseError raised
        by that rollback is ignored so that the connection still gets closed.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # we are closing anyway
        self._cnx.close()
        self._cnx = None

    def _end_transaction(self, command, failure):
        """Send the given SQL command if a transaction is pending.

        DatabaseErrors pass through unchanged; any other exception is
        replaced by an error created with _op_error(failure).  The
        transaction flag is cleared before the command is sent.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if not self._tnx:
            return
        self._tnx = False
        try:
            self._cnx.source().execute(command)
        except DatabaseError:
            raise
        except Exception:
            raise _op_error(failure)

    def commit(self):
        """Commit any pending transaction to the database."""
        self._end_transaction("COMMIT", "can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        self._end_transaction("ROLLBACK", "can't rollback")

    def cursor(self):
        """Return a new cursor object using the connection.

        Any exception from creating the cursor is replaced by an error
        created with _op_error().
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cur = self.cursor()
            cur.execute(operation, params)
            return cur

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cur = self.cursor()
            cur.executemany(operation, param_seq)
            return cur
864
865
866### Module Interface
867
# Remember the connect function bound at this point (presumably the
# low-level one from the C extension module; its import is not visible
# here), because the DB-API 2 connect() below rebinds the module name.
_connect_ = connect

def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database.

    dsn -- optional string 'host:database:user:password:opt'; trailing
        parts may be left out, and a dsn without split() is ignored
    user, password, database -- override the corresponding dsn parts
    host -- overrides the dsn host and may be given as 'host:port';
        if no valid port follows, the port stays at -1

    An empty host or user name is passed on as None.
    Returns a Connection object wrapping the new low-level connection.
    """
    # host, database, user, password and options taken from the dsn
    dsn_fields = ["", "", "", "", ""]
    try:
        dsn_parts = dsn.split(":")
        for position in range(5):
            dsn_fields[position] = dsn_parts[position]
    except (AttributeError, IndexError, TypeError):
        pass  # no dsn or fewer parts: keep what was parsed so far
    dbhost, dbbase, dbuser, dbpasswd, dbopt = dsn_fields
    dbport = -1  # no port given (passed as -1 to the low-level connect)

    # explicit arguments take precedence over the dsn
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # no usable port: dbport keeps its previous value

    # empty host is localhost; an empty user is passed as None as well
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    # open the low-level connection and wrap it
    return Connection(
        _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd))
915
916
917### Types Handling
918
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    An instance is a frozenset of type names. Compared with a string, it
    is equal if the string is one of its names, where a single leading
    underscore of the string is ignored (so Type('int4') == '_int4').
    Comparisons with anything else use the frozenset semantics.
    """

    def __new__(cls, values):
        # a string is taken as a whitespace-separated list of type names
        names = values.split() if isinstance(values, basestring) else values
        return super(Type, cls).__new__(cls, names)

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return super(Type, self).__eq__(other)
        # a leading underscore marks an array type name; match its base name
        name = other[1:] if other.startswith('_') else other
        return name in self

    def __ne__(self, other):
        if not isinstance(other, basestring):
            return super(Type, self).__ne__(other)
        name = other[1:] if other.startswith('_') else other
        return name not in self
946
947
class ArrayType:
    """Type class for PostgreSQL array types.

    An instance is equal to every type name that starts with an underscore
    (the names treated as array types here) and to other ArrayType
    instances, and unequal to everything else.
    """

    def __eq__(self, other):
        if not isinstance(other, basestring):
            return isinstance(other, ArrayType)
        return other.startswith('_')

    def __ne__(self, other):
        # exact negation of the equality test above
        return not ArrayType.__eq__(self, other)
962
963
# Mandatory type objects defined by DB-API 2 specs
# (each one compares equal to the PostgreSQL type names listed):

STRING = Type(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = Type(['bytea'])
NUMBER = Type(['int2', 'int4', 'serial', 'int8',
               'float4', 'float8', 'numeric', 'money'])
DATETIME = Type(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
                 'interval',
                 'abstime', 'reltime'])  # these two are very old
ROWID = Type(['oid'])
972
973
# Additional type objects (more specific):

BOOL = Type(['bool'])
SMALLINT = Type(['int2'])
INTEGER = Type(['int2', 'int4', 'int8', 'serial'])
LONG = Type(['int8'])
FLOAT = Type(['float4', 'float8'])
NUMERIC = Type(['numeric'])
MONEY = Type(['money'])
DATE = Type(['date'])
TIME = Type(['time', 'timetz'])
TIMESTAMP = Type(['timestamp', 'timestamptz'])
INTERVAL = Type(['interval'])
JSON = Type(['json', 'jsonb'])

# Type object for arrays: ARRAY equals every type name starting with an
# underscore. Such names also equal the type object of their base type,
# because Type ignores one leading underscore (e.g. '_int4' == INTEGER).

ARRAY = ArrayType()
992
993
994# Mandatory type helpers defined by DB-API 2 specs:
995
def Date(year, month, day):
    """Construct an object holding a date value."""
    return date(year, month, day)


def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None):
    """Construct an object holding a time value.

    Passing a tzinfo object creates a timezone aware time value.
    """
    return time(hour, minute, second, microsecond, tzinfo)


def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0,
        tzinfo=None):
    """Construct an object holding a time stamp value.

    Passing a tzinfo object creates a timezone aware time stamp.
    """
    return datetime(year, month, day, hour, minute, second, microsecond,
        tzinfo)


def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks are seconds since the epoch, interpreted in local time.
    """
    return Date(*localtime(ticks)[:3])


def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    The ticks are seconds since the epoch, interpreted in local time.
    Fractional seconds are kept as microseconds instead of being dropped
    (time.localtime() would truncate them to whole seconds).
    """
    return datetime.fromtimestamp(ticks).time()


def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    The ticks are seconds since the epoch, interpreted in local time.
    Fractional seconds are kept as microseconds instead of being dropped
    (time.localtime() would truncate them to whole seconds).
    """
    return datetime.fromtimestamp(ticks)
1024
1025
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is a plain bytes subclass; the distinct type presumably lets the
    module recognize the value as binary (bytea) data when quoting it.
    """
1028
1029
1030# Additional type helpers for PyGreSQL:
1031
class Json:
    """Construct a wrapper for holding an object serializable to JSON.

    The wrapped object is serialized with the given encode function
    (jsonencode if none is given) when the wrapper is converted with str().
    A wrapped string is returned unchanged instead of being encoded again.
    """

    def __init__(self, obj, encode=None):
        self.obj = obj  # the wrapped object
        self.encode = jsonencode if not encode else encode

    def __str__(self):
        wrapped = self.obj
        return wrapped if isinstance(wrapped, basestring) else self.encode(wrapped)

    # alias of __str__; presumably consulted by the quoting code of this
    # module when the wrapper is passed as a parameter -- verify
    __pg_repr__ = __str__
1046
1047
# When run as a script, print the version number, an empty line
# and the module documentation:

if __name__ == '__main__':
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.