source: trunk/pgdb.py @ 788

Last change on this file since 788 was 788, checked in by cito, 3 years ago

Add method columns() to type cache

The method returns the columns of composite types.
This makes the type cache even more useful.

  • Property svn:keywords set to Id
File size: 36.1 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 788 2016-01-26 21:13:55Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary or a sequence)
32    # if they are passed. The binding syntax is the same as for the %
33    # operator, and the parameter values are quoted automatically.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size and null_ok are not implemented;
52    # precision and scale are only set for numeric columns.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from re import compile as regex
76from json import loads as jsondecode, dumps as jsonencode
77
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

try:
    # The abstract base classes live in collections.abc since Python 3.3;
    # the old aliases in the collections module were removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args):
            raise NotSupportedError('OrderedDict is not supported')
102
103
### Module Constants

# compliant with DB API 2.0
apilevel = '2.0'

# module may be shared, but not connections
threadsafety = 1

# this module uses extended python format codes, i.e. %(name)s markers
# (plain %s markers with a parameter sequence are accepted as well)
paramstyle = 'pyformat'

# shortcut methods have been excluded from DB API 2 and
# are not recommended by the DB SIG, but they can be handy
shortcutmethods = 1
118
119
120### Internal Types Handling
121
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal values.

    If decimal_type is passed, it becomes the new cast function for
    the 'numeric' database type and replaces the module-level name
    Decimal (which is also used for money values). The type that is
    active after the call is returned.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    return Decimal
128
129
def _cast_bool(value):
    """Convert the text form of a boolean into a Python bool.

    Only the first character is inspected: 't' or 'T' means True,
    anything else (including an empty string) means False.
    """
    first = value[:1]
    return first == 't' or first == 'T'
132
133
def _cast_money(value):
    """Convert the text form of a money value into a Decimal.

    Every character except digits, the decimal point and the minus
    sign (e.g. currency symbols, thousands separators) is discarded.
    """
    keep = '0123456789.-'
    return Decimal(''.join(ch for ch in value if ch in keep))
137
138
# Map database type names to functions converting their text output into
# Python values. Non-array types not listed here are returned unchanged by
# TypeCache.typecast; the 'numeric' entry can be replaced by decimal_type().
_cast = {'bool': _cast_bool, 'bytea': unescape_bytea,
    'int2': int, 'int4': int, 'serial': int,
    'int8': long, 'json': jsondecode, 'jsonb': jsondecode,
    'oid': long, 'oid8': long,
    'float4': float, 'float8': float,
    'numeric': Decimal, 'money': _cast_money}
145
146
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls for the message msg.

    The returned exception gets a sqlstate attribute set to None.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
152
153
def _op_error(msg):
    """Create an OperationalError (with sqlstate None) for msg."""
    return _db_error(msg, cls=OperationalError)
157
158
# Information on a database type as read from the pg_type catalog:
# oid (type OID), name (typname), len (typlen), type (typtype, where 'c'
# denotes a composite type), category (typcategory), delim (typdelim)
# and relid (typrelid, the relation OID describing a composite type).
TypeInfo = namedtuple('TypeInfo',
    ['oid', 'name', 'len', 'type', 'category', 'delim', 'relid'])

# Name and type OID of one column of a composite type.
ColumnInfo = namedtuple('ColumnInfo', ['name', 'type'])
163
164
class TypeCache(dict):
    """Cache for database types.

    This cache maps type OIDs and names to TypeInfo tuples containing
    important information on the associated database type. Entries that
    are not cached yet are looked up in the pg_type system catalog.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The key can be a type OID (an int) or a type name. Type names
        containing neither a dot nor a double quote are put in double
        quotes before they are resolved with a ::regtype cast.

        The result is cached under its OID, its type name and the key it
        was requested with, so that looking up e.g. a schema-qualified
        name a second time does not cause another database query.

        Raises KeyError if the catalog query returns no row.
        """
        if isinstance(key, int):
            oid = name = key
        else:
            name = key
            if '.' not in name and '"' not in name:
                name = '"%s"' % name
            oid = "'%s'::regtype" % self._escape_string(name)
        self._src.execute("SELECT oid, typname,"
            " typlen, typtype, typcategory, typdelim, typrelid"
            " FROM pg_type WHERE oid=%s" % oid)
        res = self._src.fetch(1)
        if not res:
            raise KeyError('Type %s could not be found' % name)
        (typoid, typname, typlen, typtype,
            typcategory, typdelim, typrelid) = res[0]
        info = TypeInfo(int(typoid), typname, int(typlen), typtype,
            typcategory, typdelim, int(typrelid))
        self[info.oid] = self[info.name] = info
        # cache under the requested key as well, which may differ from
        # both the OID and the type name (e.g. a schema-qualified name)
        self[key] = info
        return info

    def columns(self, key):
        """Get the names and types of the columns of composite types.

        Returns a list of ColumnInfo tuples ordered by column number,
        skipping dropped columns, or an empty list if the type is
        not a composite type.
        """
        typ = self[key]
        if typ.type != 'c' or not typ.relid:
            return []  # this type is not composite
        self._src.execute("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % typ.relid)
        return [ColumnInfo(name, int(oid))
            for name, oid in self._src.fetch(-1)]

    @staticmethod
    def typecast(typ, value):
        """Cast value according to database type.

        The typ is a type name; names starting with an underscore are
        treated as array types and cast element-wise with the cast
        function of the base type. NULL values (None) are returned as-is.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        cast = _cast.get(typ)
        if cast is None:
            if typ.startswith('_'):
                # cast as an array type
                cast = _cast.get(typ[1:])
                return cast_array(value, cast)
            # no typecast available or necessary
            return value
        else:
            return cast(value)
227
228
# Matches the characters that get a backslash prepended when an array
# element is put in double quotes: double quotes and backslashes.
_re_array_escape = regex(r'(["\\])')
# Matches array elements that must be put in double quotes: those with
# braces, commas, double quotes, backslashes or whitespace, and the
# word NULL in any letter case (which would otherwise denote a NULL).
_re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
231
232
class _quotedict(dict):
    """Dictionary that quotes its values when they are looked up.

    Assign the desired quoting function to the quote attribute before
    use; every value retrieved by key is passed through that function.
    """

    def __getitem__(self, key):
        raw_value = dict.__getitem__(self, key)
        return self.quote(raw_value)
241
242
243### Cursor Object
244
class Cursor(object):
    """Cursor object."""

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self.type_cache = dbcnx.type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns;
        # it is set to True after a query and then computed on demand
        self._description = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a custom row factory has been defined, do not build one
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object."""
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Returns an SQL literal for val as a string, or the value itself
        for numbers, which are then formatted by the % operator.
        Raises InterfaceError for values of unsupported types that do not
        provide a __pg_repr__() method.
        """
        if val is None:
            return 'NULL'
        if isinstance(val, (datetime, date, time, timedelta, Json)):
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            return "'%s'" % val
        if isinstance(val, float):
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            if isnan(val):
                return "'NaN'"
            return val
        if isinstance(val, (int, long, Decimal)):
            return val
        if isinstance(val, list):
            # The array literal is embedded in an SQL string constant, so
            # it must be escaped like any other string (e.g. elements
            # containing single quotes would otherwise end the constant).
            return "'%s'" % self._cnx.escape_string(self._quote_array(val))
        if isinstance(val, tuple):
            q = self._quote
            return 'ROW(%s)' % ','.join(str(q(v)) for v in val)
        try:
            return val.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))

    def _quote_array(self, val):
        """Quote value as a literal constant for an array.

        The result is the text of an array literal (or of one element of
        it); it still needs to be quoted as an SQL string constant.
        """
        # We could also cast to an array constructor here, but that is more
        # verbose and you need to know the base type to build empty arrays.
        if isinstance(val, list):
            return '{%s}' % ','.join(self._quote_array(v) for v in val)
        if val is None:
            return 'null'
        # bool must be checked before int, since bool is a subclass of int
        if isinstance(val, bool):
            return 't' if val else 'f'
        if isinstance(val, (int, long, float)):
            return str(val)
        if isinstance(val, basestring):
            if not val:
                return '""'
            if _re_array_quote.search(val):
                return '"%s"' % _re_array_escape.sub(r'\\\1', val)
            return val
        try:
            return val.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))

    def _quoteparams(self, string, parameters):
        """Quote parameters.

        This function works for both mappings and sequences.
        """
        if isinstance(parameters, dict):
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def _make_description(self, info):
        """Make the description tuple for the given field info.

        The last four items of info are the column name, the type OID,
        the size and the type modifier of the column.
        """
        name, typ, size, mod = info[1:]
        type_info = self.type_cache[typ]
        type_code = type_info.name
        if mod > 0:
            mod -= 4
        if type_code == 'numeric':
            precision, scale = mod >> 16, mod & 0xffff
            size = precision
        else:
            if not size:
                # the length of the type is stored in the len field
                size = type_info.len
            if size == -1:
                size = mod
            precision = scale = None
        return CursorDescription(name, type_code,
            None, size, precision, scale, None)

    @property
    def description(self):
        """Read-only attribute describing the result columns."""
        descr = self._description
        if self._description is True:
            make = self._make_description
            descr = [make(info) for info in self._src.listinfo()]
            self._description = descr
        return descr

    @property
    def colnames(self):
        """Unofficial convenience method for getting the column names."""
        return [d[0] for d in self.description]

    @property
    def coltypes(self):
        """Unofficial convenience method for getting the column types."""
        return [d[1] for d in self.description]

    def close(self):
        """Close the cursor object."""
        self._src.close()
        self._description = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command).

        Returns the cursor object itself.
        """
        # The parameters may also be specified as list of tuples to e.g.
        # insert multiple rows in a single operation, but this kind of
        # usage is deprecated.  We make several plausibility checks because
        # tuples can also be passed with the meaning of ROW constructors.
        if (parameters and isinstance(parameters, list)
                and len(parameters) > 1
                and all(isinstance(p, tuple) for p in parameters)
                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
            return self.executemany(operation, parameters)
        else:
            # not a list of tuples
            return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        A transaction is started first if none is pending. Returns the
        cursor object itself, or None if seq_of_parameters is empty.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self._description = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                if parameters:
                    sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                if rows:  # true if not DML
                    rowcount += rows
                else:
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result row count and description
        if self._src.resulttype == RESULT_DQL:
            self._description = True  # fetch on demand
            self.rowcount = self._src.ntuples
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set.

        Returns None when no more rows are available.
        """
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        if not result:
            return []
        # the column types are the same for all rows, so get them only once
        typecast = self.type_cache.typecast
        coltypes = self.coltypes
        row_factory = self.row_factory
        return [row_factory([typecast(typ, value)
            for typ, value in zip(coltypes, row)]) for row in result]

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = parameters and len(parameters) or 0
        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
        self.execute(query, parameters)
        return parameters

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.

        The size option sets the size of the buffer used when reading data
        from file-like objects.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.

        Strings (or chunks of an iterable) in the non binary formats are
        terminated with a newline if necessary; binary input is passed on
        unchanged.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            if size:
                raise ValueError("size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("The input must be %s" % type_name)
                if not binary_format:
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    else:
                        if not stream.endswith(b'\n'):
                            stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "Input stream must consist of %s" % type_name)
                        # rows in the text formats must end with a newline,
                        # but binary data must not be altered
                        if not binary_format:
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            else:
                                if not chunk.endswith(b'\n'):
                                    chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("Need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, int):
                raise TypeError("The size option must be an integer")
            if size > 0:

                def chunks():
                    while True:
                        buffer = read(size)
                        yield buffer
                        if not buffer or len(buffer) < size:
                            break

            else:

                def chunks():
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("Must specify a table, not a query")
        table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy to")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            if format == 'binary':
                decode = False
            else:
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("decode option must be a boolean")
            if decode and binary_format:
                raise ValueError("decode is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            # getdata() returns the rows, and finally an int with the
            # number of rows, which is then used as the row count
            self.rowcount = 0
            while True:
                row = getdata(decode)
                if isinstance(row, int):
                    if self.rowcount != row:
                        self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no input stream, return the generator
            return copy()

        # write the rows to the file-like input stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Note that since Python 2.6 the iterator protocol uses __next()__
    # instead of next(), we keep it only for backward compatibility of pgdb.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.
        Returns None if the result has no columns.
        """
        colnames = self.colnames
        if colnames:
            try:
                try:
                    return namedtuple('Row', colnames, rename=True)._make
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    colnames = [v if v.isalnum() else 'column_%d' % n
                             for n, v in enumerate(colnames)]
                    return namedtuple('Row', colnames)._make
            except ValueError:  # there is still a problem with the field names
                colnames = ['column_%d' % n for n in range(len(colnames))]
                return namedtuple('Row', colnames)._make
808
809
# Description of one result column for the DB-API 2.0 cursor.description
# attribute. The cursor sets name, type_code (the type name) and
# internal_size; precision and scale are only set for numeric columns,
# while display_size and null_ok are always None.
CursorDescription = namedtuple('CursorDescription',
    ['name', 'type_code', 'display_size', 'internal_size',
     'precision', 'scale', 'null_ok'])
813
814
815### Connection Objects
816
class Connection(object):
    """DB-API 2 connection object wrapping a low-level connection.

    The ``_tnx`` flag records whether a transaction is pending (it is
    presumably set by cursors, outside this class); ``commit()`` and
    ``rollback()`` only send a statement to the server while it is set.
    Used as a context manager, the connection commits when the block
    succeeds and rolls back when it raises, but it is not closed.
    """

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Wrap the low-level connection ``cnx`` and verify that it works."""
        self._cnx = cnx  # underlying low-level connection, None once closed
        self._tnx = False  # True while a transaction is pending
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # class instantiated by cursor()
        try:
            self._cnx.source()  # any failure means an unusable connection
        except Exception:
            raise _op_error("invalid connection")

    def __enter__(self):
        """Enter the runtime context; used for running transactions."""
        return self

    def __exit__(self, et, ev, tb):
        """Leave the runtime context, ending the transaction.

        Commits if the block finished normally and rolls back if it
        raised.  The connection itself stays open.
        """
        failed = et is not None or ev is not None or tb is not None
        if failed:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection, rolling back any pending transaction.

        Raises an operational error if the connection is already closed.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                pass  # best effort only: the connection is closed anyway
        self._cnx.close()
        self._cnx = None

    def _end_transaction(self, command, errmsg):
        """Send ``command`` to end the pending transaction, if any.

        The transaction flag is cleared before the command is sent.
        Database errors propagate unchanged; any other failure is
        reported as an operational error with message ``errmsg``.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if self._tnx:
            self._tnx = False
            try:
                self._cnx.source().execute(command)
            except DatabaseError:
                raise
            except Exception:
                raise _op_error(errmsg)

    def commit(self):
        """Commit any pending transaction to the database."""
        self._end_transaction("COMMIT", "can't commit")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        self._end_transaction("ROLLBACK", "can't rollback")

    def cursor(self):
        """Return a new cursor (of class ``cursor_type``) on this connection."""
        if not self._cnx:
            raise _op_error("connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Run an operation on a new implicit cursor and return it."""
            cur = self.cursor()
            cur.execute(operation, params)
            return cur

        def executemany(self, operation, param_seq):
            """Run an operation for each parameter set on a new cursor."""
            cur = self.cursor()
            cur.executemany(operation, param_seq)
            return cur
924
925
926### Module Interface
927
# Keep a reference to the ``connect`` already bound at this point
# (presumably the low-level one imported at the top of the file -- not
# visible here) before the DB-API ``connect()`` defined next rebinds it.
_connect_ = connect
929
def connect(dsn=None, user=None, password=None, host=None, database=None):
    """Connect to a database and return a Connection object.

    The optional ``dsn`` string has the form
    'host:database:user:password:opt', where every part may be missing.
    The keyword arguments override the corresponding DSN parts, and
    ``host`` may carry a port as 'host:port'.  If no valid port is
    given, -1 is passed on.  An empty host or user is passed on as None.
    """
    # Split the DSN into its (at most five) colon-separated parts,
    # leaving any missing part as an empty string.
    dsn_parts = ["", "", "", "", ""]
    try:
        for index, part in enumerate(dsn.split(":")[:5]):
            dsn_parts[index] = part
    except (AttributeError, IndexError, TypeError):
        pass
    dbhost, dbbase, dbuser, dbpasswd, dbopt = dsn_parts
    dbport = -1  # no port given

    # explicit keyword arguments take precedence over the DSN
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass

    # pass empty host and user as None (presumably selecting the
    # low-level defaults; an unset host does not necessarily mean TCP
    # to localhost)
    if dbhost == "":
        dbhost = None
    if dbuser == "":
        dbuser = None

    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
975
976
977### Types Handling
978
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozenset of type names.  It compares equal to a type
    name string contained in the set.  A leading underscore (the prefix
    of PostgreSQL array type names) is stripped first, so array type
    names also compare equal to the Type of their base type.  Comparisons
    with non-string objects use ordinary frozenset semantics.
    """

    def __new__(cls, values):
        """Create the set from an iterable of names or a space-separated string."""
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            if other.startswith('_'):
                other = other[1:]
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Overriding __eq__ makes Python 3 set __hash__ to None, which would
    # make Type objects unhashable there (unlike on Python 2); keep the
    # frozenset hash instead.
    __hash__ = frozenset.__hash__
1006
1007
class ArrayType:
    """Type class for PostgreSQL array types.

    An ArrayType compares equal to any type name string starting with an
    underscore (the naming convention for PostgreSQL array types) and to
    any other ArrayType instance.
    """

    def __eq__(self, other):
        if isinstance(other, basestring):
            return other.startswith('_')
        else:
            return isinstance(other, ArrayType)

    def __ne__(self, other):
        if isinstance(other, basestring):
            return not other.startswith('_')
        else:
            return not isinstance(other, ArrayType)

    def __hash__(self):
        # Overriding __eq__ makes Python 3 set __hash__ to None; restore
        # hashability.  All ArrayType instances compare equal to each
        # other, so they must all share one hash value.
        return hash(ArrayType)
1022
1023
1024# Mandatory type objects defined by DB-API 2 specs:
1025
# each type object is the set of PostgreSQL type names it stands for
STRING = Type(('char', 'bpchar', 'name', 'text', 'varchar'))
BINARY = Type(('bytea',))
NUMBER = Type(('int2', 'int4', 'serial', 'int8',
               'float4', 'float8', 'numeric', 'money'))
DATETIME = Type(('date', 'time', 'timetz', 'timestamp', 'timestamptz',
                 'interval',
                 'abstime', 'reltime'))  # the last two are very old
ROWID = Type(('oid',))
1032
1033
1034# Additional type objects (more specific):
1035
# narrower sets of PostgreSQL type names than the DB-API ones
BOOL = Type(('bool',))
SMALLINT = Type(('int2',))
INTEGER = Type(('int2', 'int4', 'int8', 'serial'))
LONG = Type(('int8',))
FLOAT = Type(('float4', 'float8'))
NUMERIC = Type(('numeric',))
MONEY = Type(('money',))
DATE = Type(('date',))
TIME = Type(('time', 'timetz'))
TIMESTAMP = Type(('timestamp', 'timestamptz'))
INTERVAL = Type(('interval',))
JSON = Type(('json', 'jsonb'))
1048
# Type object for arrays (array type names also compare equal to the type objects of their base types):
1050
# Module-level ArrayType instance: compares equal to every type name that
# starts with an underscore, i.e. to every array type name.
ARRAY = ArrayType()
1052
1053
1054# Mandatory type helpers defined by DB-API 2 specs:
1055
def Date(year, month, day):
    """Return an object holding the date year-month-day."""
    return date(year=year, month=month, day=day)
1059
1060
def Time(hour, minute=0, second=0, microsecond=0):
    """Return an object holding a time of day; omitted parts default to 0."""
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond)
1064
1065
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Return an object holding a date and time; time parts default to 0."""
    return datetime(year=year, month=month, day=day, hour=hour,
                    minute=minute, second=second, microsecond=microsecond)
1069
1070
def DateFromTicks(ticks):
    """Return a date value for ``ticks`` seconds, converted with localtime()."""
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
1074
1075
def TimeFromTicks(ticks):
    """Return a time value for ``ticks`` seconds, converted with localtime()."""
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
1079
1080
def TimestampFromTicks(ticks):
    """Return a time stamp for ``ticks`` seconds, converted with localtime().

    Only whole seconds are used; the microsecond part is left at 0.
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
1084
1085
class Binary(bytes):
    """Hold a binary (long) string value; a plain ``bytes`` subclass."""

    pass
1088
1089
1090# Additional type helpers for PyGreSQL:
1091
class Json:
    """Wrap an object so that it can be passed on as JSON.

    ``encode`` is the serializer applied to the wrapped object; when it is
    not given (or falsy), the module-level ``jsonencode`` is used.  A
    wrapped string object is returned unchanged by ``str()``, presumably
    because it is expected to contain JSON text already.
    """

    def __init__(self, obj, encode=None):
        self.obj = obj
        self.encode = encode if encode else jsonencode

    def __str__(self):
        value = self.obj
        return value if isinstance(value, basestring) else self.encode(value)

    # the same conversion under the __pg_repr__ name (presumably looked up
    # by the parameter adaptation code elsewhere)
    __pg_repr__ = __str__
1106
1107
1108# If run as script, print some information:
1109
if __name__ == '__main__':
    # When run as a script, show the module version followed by the
    # module docstring (the usage overview).
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.