source: trunk/pgdb.py @ 792

Last change on this file since 792 was 792, checked in by cito, 3 years ago

Using ARRAY and ROW constructor in pgdb again

Using the special input syntax for quoting arrays and rows had some
advantages, but one big disadvantage, namely the missing type information.
Therefore, this change has been reverted, we now use ARRAY and ROW
constructor syntax again to quote lists and tuples. See comments.

The code has become simpler again and doesn't need the re module any more.

  • Property svn:keywords set to Id
File size: 37.1 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 792 2016-01-28 19:54:34Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
31    # Execute a query, binding params (a dictionary or a sequence)
32    # if they are passed. The binding syntax is that of the % operator;
33    # values are quoted automatically, so do not quote placeholders.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from json import loads as jsondecode, dumps as jsonencode
76
# Compatibility shims so the same source runs on Python 2 and Python 3.

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# The abstract base classes live in collections.abc since Python 3.3 and
# the aliases in the collections module were removed in Python 3.10.
try:
    from collections.abc import Iterable
except ImportError:  # Python 2
    from collections import Iterable
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args):
            """Placeholder raising an error when no OrderedDict exists."""
            raise NotSupportedError('OrderedDict is not supported')
101
102
103### Module Constants
104
# this module complies with version 2.0 of the DB API
apilevel = '2.0'

# threads may share the module, but must not share connections
threadsafety = 1

# query parameters use the extended Python format codes
paramstyle = 'pyformat'

# shortcut methods have been excluded from DB API 2 and are not
# recommended by the DB SIG, but they can be handy
shortcutmethods = 1
117
118
119### Internal Types Handling
120
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal values.

    When called with a type, that type replaces the module level Decimal
    and is registered as the cast for 'numeric' values.  The type that is
    in effect after the call is returned in either case.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    return Decimal
127
128
def _cast_bool(value):
    """Return True if the value starts with the letter 't' or 'T'."""
    first = value[:1]
    return first == 't' or first == 'T'
131
132
def _cast_money(value):
    """Convert a money string to a Decimal.

    All characters other than digits, the dot and the minus sign
    are dropped before the conversion.
    """
    kept = [char for char in value if char in '0123456789.-']
    return Decimal(''.join(kept))
136
137
# Mapping of database type names to the functions used for converting
# the string values delivered by the database to Python objects.
_cast = {
    # character types
    'char': str,
    'bpchar': str,
    'name': str,
    'text': str,
    'varchar': str,
    # boolean and binary types
    'bool': _cast_bool,
    'bytea': unescape_bytea,
    # integer types
    'int2': int,
    'int4': int,
    'serial': int,
    'int8': long,
    # JSON types
    'json': jsondecode,
    'jsonb': jsondecode,
    # object identifiers
    'oid': long,
    'oid8': long,
    # floating point and arbitrary precision types
    'float4': float,
    'float8': float,
    'numeric': Decimal,
    'money': _cast_money,
    # composite types
    'record': cast_record,
}
147
148
def _db_error(msg, cls=DatabaseError):
    """Create an error of the given class with sqlstate set to None.

    The error is returned, not raised; cls defaults to DatabaseError.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
154
155
def _op_error(msg):
    """Create an OperationalError with sqlstate set to None."""
    return _db_error(msg, cls=OperationalError)
159
160
# information on a database type as kept in the type cache
TypeInfo = namedtuple(
    'TypeInfo', 'oid name len type category delim relid')

# name and type of one column of a composite type
ColumnInfo = namedtuple('ColumnInfo', 'name type')
165
166
class TypeCache(dict):
    """Cache for database types.

    The cache is a dictionary mapping type OIDs and type names to TypeInfo
    tuples.  Entries that are missing are looked up in the pg_type catalog
    and stored under both the OID and the name of the type.
    """

    def __init__(self, cnx):
        """Initialize the type cache for the given connection."""
        super(TypeCache, self).__init__()
        self._escape_string = cnx.escape_string
        self._src = cnx.source()

    def __missing__(self, key):
        """Look up a type in the database that has not been cached yet.

        Integer keys are taken as OIDs, all other keys as type names.
        Raises a KeyError if the query fails with a ProgrammingError
        or does not return a row.
        """
        if isinstance(key, int):
            oid = key
        else:
            # simple names are quoted before they are cast to regtype
            if '.' not in key and '"' not in key:
                key = '"%s"' % key
            oid = "'%s'::regtype" % self._escape_string(key)
        query = ("SELECT oid, typname,"
                 " typlen, typtype, typcategory, typdelim, typrelid"
                 " FROM pg_type WHERE oid=%s" % oid)
        source = self._src
        try:
            source.execute(query)
        except ProgrammingError:
            rows = None
        else:
            rows = source.fetch(1)
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        fields = list(rows[0])
        # oid, typlen and typrelid are stored as integers
        for pos in (0, 2, 6):
            fields[pos] = int(fields[pos])
        info = TypeInfo(*fields)
        self[info.oid] = self[info.name] = info
        return info

    def get(self, key, default=None):
        """Get the type info, querying the database if it is not cached.

        The default is returned if the type cannot be found.
        """
        try:
            info = self[key]
        except KeyError:
            info = default
        return info

    def columns(self, key):
        """Get the names and type OIDs of the columns of a composite type.

        Returns a list of ColumnInfo tuples, or None if the type is
        not known or is not a composite type.
        """
        try:
            info = self[key]
        except KeyError:
            return None  # unknown type
        if info.type != 'c' or not info.relid:
            return None  # not a composite type
        source = self._src
        source.execute("SELECT attname, atttypid"
            " FROM pg_attribute WHERE attrelid=%s AND attnum>0"
            " AND NOT attisdropped ORDER BY attnum" % info.relid)
        return [ColumnInfo(name, int(oid))
            for name, oid in source.fetch(-1)]

    def typecast(self, typ, value):
        """Cast the value according to the database type name.

        NULL values (None) and string types are passed through.  Types
        without a registered cast are treated as arrays if their name
        starts with an underscore, and as records if they are composite
        types; all other values are returned unchanged.
        """
        if value is None:
            return None  # NULL values need no typecast
        cast = _cast.get(typ)
        if cast is str:
            return value  # already a string
        if cast is not None:
            return cast(value)
        if typ.startswith('_'):
            # array type, the element type follows the underscore
            return cast_array(value, _cast.get(typ[1:]))
        cols = self.columns(typ)
        if not cols:
            return value  # no typecast available or necessary
        # composite type, returned as a named tuple
        casts = [self.getcast(col.type) for col in cols]
        values = cast_record(value, casts)
        record = namedtuple(typ, [col.name for col in cols])
        return record(*values)

    def getcast(self, key):
        """Get a cast function for the given type OID or type name.

        Returns None if an OID is passed that cannot be found.
        """
        if not isinstance(key, int):
            typ = key
        else:
            try:
                typ = self[key].name
            except KeyError:
                return None
        typecast = self.typecast

        def cast(value):
            return typecast(typ, value)

        return cast
264
265
class _quotedict(dict):
    """Dictionary that quotes its items when they are looked up.

    The quote attribute must be set to the desired quote function
    before items are accessed.
    """

    def __getitem__(self, key):
        raw = super(_quotedict, self).__getitem__(key)
        return self.quote(raw)
274
275
276### Cursor Object
277
278class Cursor(object):
279    """Cursor object."""
280
281    def __init__(self, dbcnx):
282        """Create a cursor object for the database connection."""
283        self.connection = self._dbcnx = dbcnx
284        self._cnx = dbcnx._cnx
285        self.type_cache = dbcnx.type_cache
286        self._src = self._cnx.source()
287        # the official attribute for describing the result columns
288        self._description = None
289        if self.row_factory is Cursor.row_factory:
290            # the row factory needs to be determined dynamically
291            self.row_factory = None
292        else:
293            self.build_row_factory = None
294        self.rowcount = -1
295        self.arraysize = 1
296        self.lastrowid = None
297
298    def __iter__(self):
299        """Make cursor compatible to the iteration protocol."""
300        return self
301
302    def __enter__(self):
303        """Enter the runtime context for the cursor object."""
304        return self
305
306    def __exit__(self, et, ev, tb):
307        """Exit the runtime context for the cursor object."""
308        self.close()
309
310    def _quote(self, val):
311        """Quote value depending on its type."""
312        if val is None:
313            return 'NULL'
314        if isinstance(val, (datetime, date, time, timedelta, Json)):
315            val = str(val)
316        if isinstance(val, basestring):
317            if isinstance(val, Binary):
318                val = self._cnx.escape_bytea(val)
319                if bytes is not str:  # Python >= 3.0
320                    val = val.decode('ascii')
321            else:
322                val = self._cnx.escape_string(val)
323            return "'%s'" % val
324        if isinstance(val, float):
325            if isinf(val):
326                return "'-Infinity'" if val < 0 else "'Infinity'"
327            if isnan(val):
328                return "'NaN'"
329            return val
330        if isinstance(val, (int, long, Decimal)):
331            return val
332        if isinstance(val, list):
333            # Quote value as an ARRAY constructor. This is better than using
334            # an array literal because it carries the information that this is
335            # an array and not a string.  One issue with this syntax is that
336            # you need to add an explicit type cast when passing empty arrays.
337            # The ARRAY keyword is actually only necessary at the top level.
338            q = self._quote
339            return 'ARRAY[%s]' % ','.join(str(q(v)) for v in val)
340        if isinstance(val, tuple):
341            # Quote as a ROW constructor.  This is better than using a record
342            # literal because it carries the information that this is a record
343            # and not a string.  We don't use the keyword ROW in order to make
344            # this usable with the IN synntax as well.  It is only necessary
345            # when the records has a single column which is not really useful.
346            q = self._quote
347            return '(%s)' % ','.join(str(q(v)) for v in val)
348        try:
349            return val.__pg_repr__()
350        except AttributeError:
351            raise InterfaceError(
352                'do not know how to handle type %s' % type(val))
353
354
355    def _quoteparams(self, string, parameters):
356        """Quote parameters.
357
358        This function works for both mappings and sequences.
359        """
360        if isinstance(parameters, dict):
361            parameters = _quotedict(parameters)
362            parameters.quote = self._quote
363        else:
364            parameters = tuple(map(self._quote, parameters))
365        return string % parameters
366
367    def _make_description(self, info):
368        """Make the description tuple for the given field info."""
369        name, typ, size, mod = info[1:]
370        type_info = self.type_cache[typ]
371        type_code = type_info.name
372        if mod > 0:
373            mod -= 4
374        if type_code == 'numeric':
375            precision, scale = mod >> 16, mod & 0xffff
376            size = precision
377        else:
378            if not size:
379                size = type_info.size
380            if size == -1:
381                size = mod
382            precision = scale = None
383        return CursorDescription(name, type_code,
384            None, size, precision, scale, None)
385
386    @property
387    def description(self):
388        """Read-only attribute describing the result columns."""
389        descr = self._description
390        if self._description is True:
391            make = self._make_description
392            descr = [make(info) for info in self._src.listinfo()]
393            self._description = descr
394        return descr
395
396    @property
397    def colnames(self):
398        """Unofficial convenience method for getting the column names."""
399        return [d[0] for d in self.description]
400
401    @property
402    def coltypes(self):
403        """Unofficial convenience method for getting the column types."""
404        return [d[1] for d in self.description]
405
406    def close(self):
407        """Close the cursor object."""
408        self._src.close()
409        self._description = None
410        self.rowcount = -1
411        self.lastrowid = None
412
413    def execute(self, operation, parameters=None):
414        """Prepare and execute a database operation (query or command)."""
415        # The parameters may also be specified as list of tuples to e.g.
416        # insert multiple rows in a single operation, but this kind of
417        # usage is deprecated.  We make several plausibility checks because
418        # tuples can also be passed with the meaning of ROW constructors.
419        if (parameters and isinstance(parameters, list)
420                and len(parameters) > 1
421                and all(isinstance(p, tuple) for p in parameters)
422                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
423            return self.executemany(operation, parameters)
424        else:
425            # not a list of tuples
426            return self.executemany(operation, [parameters])
427
428    def executemany(self, operation, seq_of_parameters):
429        """Prepare operation and execute it against a parameter sequence."""
430        if not seq_of_parameters:
431            # don't do anything without parameters
432            return
433        self._description = None
434        self.rowcount = -1
435        # first try to execute all queries
436        rowcount = 0
437        sql = "BEGIN"
438        try:
439            if not self._dbcnx._tnx:
440                try:
441                    self._cnx.source().execute(sql)
442                except DatabaseError:
443                    raise  # database provides error message
444                except Exception as err:
445                    raise _op_error("can't start transaction")
446                self._dbcnx._tnx = True
447            for parameters in seq_of_parameters:
448                sql = operation
449                if parameters:
450                    sql = self._quoteparams(sql, parameters)
451                rows = self._src.execute(sql)
452                if rows:  # true if not DML
453                    rowcount += rows
454                else:
455                    self.rowcount = -1
456        except DatabaseError:
457            raise  # database provides error message
458        except Error as err:
459            raise _db_error(
460                "error in '%s': '%s' " % (sql, err), InterfaceError)
461        except Exception as err:
462            raise _op_error("internal error in '%s': %s" % (sql, err))
463        # then initialize result raw count and description
464        if self._src.resulttype == RESULT_DQL:
465            self._description = True  # fetch on demand
466            self.rowcount = self._src.ntuples
467            self.lastrowid = None
468            if self.build_row_factory:
469                self.row_factory = self.build_row_factory()
470        else:
471            self.rowcount = rowcount
472            self.lastrowid = self._src.oidstatus()
473        # return the cursor object, so you can write statements such as
474        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
475        return self
476
477    def fetchone(self):
478        """Fetch the next row of a query result set."""
479        res = self.fetchmany(1, False)
480        try:
481            return res[0]
482        except IndexError:
483            return None
484
485    def fetchall(self):
486        """Fetch all (remaining) rows of a query result."""
487        return self.fetchmany(-1, False)
488
489    def fetchmany(self, size=None, keep=False):
490        """Fetch the next set of rows of a query result.
491
492        The number of rows to fetch per call is specified by the
493        size parameter. If it is not given, the cursor's arraysize
494        determines the number of rows to be fetched. If you set
495        the keep parameter to true, this is kept as new arraysize.
496        """
497        if size is None:
498            size = self.arraysize
499        if keep:
500            self.arraysize = size
501        try:
502            result = self._src.fetch(size)
503        except DatabaseError:
504            raise
505        except Error as err:
506            raise _db_error(str(err))
507        typecast = self.type_cache.typecast
508        return [self.row_factory([typecast(typ, value)
509            for typ, value in zip(self.coltypes, row)]) for row in result]
510
511    def callproc(self, procname, parameters=None):
512        """Call a stored database procedure with the given name.
513
514        The sequence of parameters must contain one entry for each input
515        argument that the procedure expects. The result of the call is the
516        same as this input sequence; replacement of output and input/output
517        parameters in the return value is currently not supported.
518
519        The procedure may also provide a result set as output. These can be
520        requested through the standard fetch methods of the cursor.
521        """
522        n = parameters and len(parameters) or 0
523        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
524        self.execute(query, parameters)
525        return parameters
526
527    def copy_from(self, stream, table,
528            format=None, sep=None, null=None, size=None, columns=None):
529        """Copy data from an input stream to the specified table.
530
531        The input stream can be a file-like object with a read() method or
532        it can also be an iterable returning a row or multiple rows of input
533        on each iteration.
534
535        The format must be text, csv or binary. The sep option sets the
536        column separator (delimiter) used in the non binary formats.
537        The null option sets the textual representation of NULL in the input.
538
539        The size option sets the size of the buffer used when reading data
540        from file-like objects.
541
542        The copy operation can be restricted to a subset of columns. If no
543        columns are specified, all of them will be copied.
544        """
545        binary_format = format == 'binary'
546        try:
547            read = stream.read
548        except AttributeError:
549            if size:
550                raise ValueError("size must only be set for file-like objects")
551            if binary_format:
552                input_type = bytes
553                type_name = 'byte strings'
554            else:
555                input_type = basestring
556                type_name = 'strings'
557
558            if isinstance(stream, basestring):
559                if not isinstance(stream, input_type):
560                    raise ValueError("The input must be %s" % type_name)
561                if not binary_format:
562                    if isinstance(stream, str):
563                        if not stream.endswith('\n'):
564                            stream += '\n'
565                    else:
566                        if not stream.endswith(b'\n'):
567                            stream += b'\n'
568
569                def chunks():
570                    yield stream
571
572            elif isinstance(stream, Iterable):
573
574                def chunks():
575                    for chunk in stream:
576                        if not isinstance(chunk, input_type):
577                            raise ValueError(
578                                "Input stream must consist of %s" % type_name)
579                        if isinstance(chunk, str):
580                            if not chunk.endswith('\n'):
581                                chunk += '\n'
582                        else:
583                            if not chunk.endswith(b'\n'):
584                                chunk += b'\n'
585                        yield chunk
586
587            else:
588                raise TypeError("Need an input stream to copy from")
589        else:
590            if size is None:
591                size = 8192
592            elif not isinstance(size, int):
593                raise TypeError("The size option must be an integer")
594            if size > 0:
595
596                def chunks():
597                    while True:
598                        buffer = read(size)
599                        yield buffer
600                        if not buffer or len(buffer) < size:
601                            break
602
603            else:
604
605                def chunks():
606                    yield read()
607
608        if not table or not isinstance(table, basestring):
609            raise TypeError("Need a table to copy to")
610        if table.lower().startswith('select'):
611                raise ValueError("Must specify a table, not a query")
612        else:
613            table = '"%s"' % (table,)
614        operation = ['copy %s' % (table,)]
615        options = []
616        params = []
617        if format is not None:
618            if not isinstance(format, basestring):
619                raise TypeError("format option must be be a string")
620            if format not in ('text', 'csv', 'binary'):
621                raise ValueError("invalid format")
622            options.append('format %s' % (format,))
623        if sep is not None:
624            if not isinstance(sep, basestring):
625                raise TypeError("sep option must be a string")
626            if format == 'binary':
627                raise ValueError("sep is not allowed with binary format")
628            if len(sep) != 1:
629                raise ValueError("sep must be a single one-byte character")
630            options.append('delimiter %s')
631            params.append(sep)
632        if null is not None:
633            if not isinstance(null, basestring):
634                raise TypeError("null option must be a string")
635            options.append('null %s')
636            params.append(null)
637        if columns:
638            if not isinstance(columns, basestring):
639                columns = ','.join('"%s"' % (col,) for col in columns)
640            operation.append('(%s)' % (columns,))
641        operation.append("from stdin")
642        if options:
643            operation.append('(%s)' % ','.join(options))
644        operation = ' '.join(operation)
645
646        putdata = self._src.putdata
647        self.execute(operation, params)
648
649        try:
650            for chunk in chunks():
651                putdata(chunk)
652        except BaseException as error:
653            self.rowcount = -1
654            # the following call will re-raise the error
655            putdata(error)
656        else:
657            self.rowcount = putdata(None)
658
659        # return the cursor object, so you can chain operations
660        return self
661
662    def copy_to(self, stream, table,
663            format=None, sep=None, null=None, decode=None, columns=None):
664        """Copy data from the specified table to an output stream.
665
666        The output stream can be a file-like object with a write() method or
667        it can also be None, in which case the method will return a generator
668        yielding a row on each iteration.
669
670        Output will be returned as byte strings unless you set decode to true.
671
672        Note that you can also use a select query instead of the table name.
673
674        The format must be text, csv or binary. The sep option sets the
675        column separator (delimiter) used in the non binary formats.
676        The null option sets the textual representation of NULL in the output.
677
678        The copy operation can be restricted to a subset of columns. If no
679        columns are specified, all of them will be copied.
680        """
681        binary_format = format == 'binary'
682        if stream is not None:
683            try:
684                write = stream.write
685            except AttributeError:
686                raise TypeError("need an output stream to copy to")
687        if not table or not isinstance(table, basestring):
688            raise TypeError("need a table to copy to")
689        if table.lower().startswith('select'):
690            if columns:
691                raise ValueError("columns must be specified in the query")
692            table = '(%s)' % (table,)
693        else:
694            table = '"%s"' % (table,)
695        operation = ['copy %s' % (table,)]
696        options = []
697        params = []
698        if format is not None:
699            if not isinstance(format, basestring):
700                raise TypeError("format option must be a string")
701            if format not in ('text', 'csv', 'binary'):
702                raise ValueError("invalid format")
703            options.append('format %s' % (format,))
704        if sep is not None:
705            if not isinstance(sep, basestring):
706                raise TypeError("sep option must be a string")
707            if binary_format:
708                raise ValueError("sep is not allowed with binary format")
709            if len(sep) != 1:
710                raise ValueError("sep must be a single one-byte character")
711            options.append('delimiter %s')
712            params.append(sep)
713        if null is not None:
714            if not isinstance(null, basestring):
715                raise TypeError("null option must be a string")
716            options.append('null %s')
717            params.append(null)
718        if decode is None:
719            if format == 'binary':
720                decode = False
721            else:
722                decode = str is unicode
723        else:
724            if not isinstance(decode, (int, bool)):
725                raise TypeError("decode option must be a boolean")
726            if decode and binary_format:
727                raise ValueError("decode is not allowed with binary format")
728        if columns:
729            if not isinstance(columns, basestring):
730                columns = ','.join('"%s"' % (col,) for col in columns)
731            operation.append('(%s)' % (columns,))
732
733        operation.append("to stdout")
734        if options:
735            operation.append('(%s)' % ','.join(options))
736        operation = ' '.join(operation)
737
738        getdata = self._src.getdata
739        self.execute(operation, params)
740
741        def copy():
742            self.rowcount = 0
743            while True:
744                row = getdata(decode)
745                if isinstance(row, int):
746                    if self.rowcount != row:
747                        self.rowcount = row
748                    break
749                self.rowcount += 1
750                yield row
751
752        if stream is None:
753            # no input stream, return the generator
754            return copy()
755
756        # write the rows to the file-like input stream
757        for row in copy():
758            write(row)
759
760        # return the cursor object, so you can chain operations
761        return self
762
763    def __next__(self):
764        """Return the next row (support for the iteration protocol)."""
765        res = self.fetchone()
766        if res is None:
767            raise StopIteration
768        return res
769
770    # Note that since Python 2.6 the iterator protocol uses __next()__
771    # instead of next(), we keep it only for backward compatibility of pgdb.
772    next = __next__
773
774    @staticmethod
775    def nextset():
776        """Not supported."""
777        raise NotSupportedError("nextset() is not supported")
778
779    @staticmethod
780    def setinputsizes(sizes):
781        """Not supported."""
782        pass  # unsupported, but silently passed
783
784    @staticmethod
785    def setoutputsize(size, column=0):
786        """Not supported."""
787        pass  # unsupported, but silently passed
788
789    @staticmethod
790    def row_factory(row):
791        """Process rows before they are returned.
792
793        You can overwrite this statically with a custom row factory, or
794        you can build a row factory dynamically with build_row_factory().
795
796        For example, you can create a Cursor class that returns rows as
797        Python dictionaries like this:
798
799            class DictCursor(pgdb.Cursor):
800
801                def row_factory(self, row):
802                    return {desc[0]: value
803                        for desc, value in zip(self.description, row)}
804
805            cur = DictCursor(con)  # get one DictCursor instance or
806            con.cursor_type = DictCursor  # always use DictCursor instances
807        """
808        raise NotImplementedError
809
810    def build_row_factory(self):
811        """Build a row factory based on the current description.
812
813        This implementation builds a row factory for creating named tuples.
814        You can overwrite this method if you want to dynamically create
815        different row factories whenever the column description changes.
816        """
817        colnames = self.colnames
818        if colnames:
819            try:
820                try:
821                    return namedtuple('Row', colnames, rename=True)._make
822                except TypeError:  # Python 2.6 and 3.0 do not support rename
823                    colnames = [v if v.isalnum() else 'column_%d' % n
824                             for n, v in enumerate(colnames)]
825                    return namedtuple('Row', colnames)._make
826            except ValueError:  # there is still a problem with the field names
827                colnames = ['column_%d' % n for n in range(len(colnames))]
828                return namedtuple('Row', colnames)._make
829
830
# Named tuple type with the seven fields of a DB-API 2 column description
CursorDescription = namedtuple(
    'CursorDescription',
    ('name', 'type_code', 'display_size', 'internal_size',
     'precision', 'scale', 'null_ok'))
834
835
836### Connection Objects
837
class Connection(object):
    """DB-API 2 connection object wrapping a low-level connection."""

    # make the module exceptions available on every connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Wrap the given low-level connection.

        An operational error ("invalid connection") is raised when no
        source object can be obtained from cnx.
        """
        self._cnx = cnx  # connection
        self._tnx = False  # transaction state
        self.type_cache = TypeCache(cnx)
        self.cursor_type = Cursor  # class instantiated by cursor()
        try:
            # probe the connection by asking it for a source object
            self._cnx.source()
        except Exception:
            raise _op_error("invalid connection")

    def __enter__(self):
        """Start using the connection as a context manager.

        The with block can be used for running a transaction;
        the connection object itself is returned.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Leave the with block, ending the transaction.

        The connection stays open: it is rolled back if the block was
        left with an exception and committed otherwise.
        """
        if et is not None or ev is not None or tb is not None:
            self.rollback()
        else:
            self.commit()

    def close(self):
        """Close the connection, rolling back a pending transaction.

        Raises an operational error if the connection is already closed.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if self._tnx:
            try:
                self.rollback()
            except DatabaseError:
                # a failing rollback must not prevent closing
                pass
        self._cnx.close()
        self._cnx = None

    def commit(self):
        """Commit the pending transaction, if there is one.

        Raises an operational error if the connection is already closed
        or if the COMMIT fails with anything but a DatabaseError.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if not self._tnx:
            return
        # the transaction counts as finished even if COMMIT fails
        self._tnx = False
        try:
            self._cnx.source().execute("COMMIT")
        except DatabaseError:
            raise
        except Exception:
            raise _op_error("can't commit")

    def rollback(self):
        """Roll back the pending transaction, if there is one.

        Raises an operational error if the connection is already closed
        or if the ROLLBACK fails with anything but a DatabaseError.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        if not self._tnx:
            return
        # the transaction counts as finished even if ROLLBACK fails
        self._tnx = False
        try:
            self._cnx.source().execute("ROLLBACK")
        except DatabaseError:
            raise
        except Exception:
            raise _op_error("can't rollback")

    def cursor(self):
        """Create a new cursor of type cursor_type for this connection.

        Raises an operational error if the connection is already closed
        or if the cursor cannot be created.
        """
        if not self._cnx:
            raise _op_error("connection has been closed")
        try:
            return self.cursor_type(self)
        except Exception:
            raise _op_error("invalid connection")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Run an operation on a fresh cursor and return that cursor."""
            new_cursor = self.cursor()
            new_cursor.execute(operation, params)
            return new_cursor

        def executemany(self, operation, param_seq):
            """Run an operation for a sequence of parameters on a fresh
            cursor and return that cursor."""
            new_cursor = self.cursor()
            new_cursor.executemany(operation, param_seq)
            return new_cursor
945
946
947### Module Interface
948
_connect_ = connect

def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Open a database connection and return a Connection object.

    The dsn has the form 'host:database:user:password:opt', where
    trailing parts may be left out.  The keyword arguments user,
    password, database and host take precedence over the dsn; host may
    be given as 'hostname:port'.  Empty host and user names are passed
    on as None.
    """
    # start with the parts found in the DSN; missing parts stay empty
    dbhost = dbbase = dbuser = dbpasswd = dbopt = ""
    dbport = -1
    try:
        parts = list(dsn.split(":"))
    except (AttributeError, IndexError, TypeError):
        pass  # no usable DSN given
    else:
        parts = (parts + [""] * 5)[:5]
        dbhost, dbbase, dbuser, dbpasswd, dbopt = parts

    # explicit keyword arguments win over the DSN
    dbuser = dbuser if user is None else user
    dbpasswd = dbpasswd if password is None else password
    dbbase = dbbase if database is None else database
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            # a missing or non-numeric port leaves dbport unchanged
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass

    # empty host is localhost
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    # open the connection
    return Connection(
        _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd))
996
997
998### Types Handling
999
class Type(frozenset):
    """Type object standing for a group of PostgreSQL data types.

    Since PostgreSQL types are dynamic, the type names themselves serve
    as internal type codes.  A Type is a frozen set of such names that
    also compares equal to every single name it contains, and to the
    corresponding array type name carrying a leading underscore.
    """

    def __new__(cls, values):
        """Create the set from an iterable of type names or from one
        string of whitespace separated type names."""
        names = values.split() if isinstance(values, basestring) else values
        return super(Type, cls).__new__(cls, names)

    def __eq__(self, other):
        """Check whether a type name belongs to this type object."""
        if not isinstance(other, basestring):
            # everything else is compared like an ordinary frozenset
            return super(Type, self).__eq__(other)
        # one leading underscore is stripped before the lookup
        name = other[1:] if other.startswith('_') else other
        return name in self

    def __ne__(self, other):
        """Check whether a type name does not belong to this type object."""
        if not isinstance(other, basestring):
            return super(Type, self).__ne__(other)
        name = other[1:] if other.startswith('_') else other
        return name not in self
1027
1028
class ArrayType:
    """Type object that stands for all PostgreSQL array types.

    It compares equal to every type name starting with an underscore
    and to every other ArrayType instance.
    """

    def __eq__(self, other):
        """Check whether other is an array type name or an ArrayType."""
        return (other.startswith('_') if isinstance(other, basestring)
            else isinstance(other, ArrayType))

    def __ne__(self, other):
        """Check whether other is no array type name and no ArrayType."""
        if isinstance(other, basestring):
            is_array = other.startswith('_')
        else:
            is_array = isinstance(other, ArrayType)
        return not is_array
1043
1044
1045# Mandatory type objects defined by DB-API 2 specs:
1046
# Each type object is built from a string of whitespace separated
# PostgreSQL type names and compares equal to each of these names.
STRING = Type('char bpchar name text varchar')
BINARY = Type('bytea')
NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
DATETIME = Type('date time timetz timestamp timestamptz interval'
    ' abstime reltime')  # these are very old
ROWID = Type('oid')


# Additional type objects (more specific):

# Their type names overlap with the mandatory type objects above, so a
# single type name can compare equal to more than one type object.
BOOL = Type('bool')
SMALLINT = Type('int2')
INTEGER = Type('int2 int4 int8 serial')
LONG = Type('int8')
FLOAT = Type('float4 float8')
NUMERIC = Type('numeric')
MONEY = Type('money')
DATE = Type('date')
TIME = Type('time timetz')
TIMESTAMP = Type('timestamp timestamptz')
INTERVAL = Type('interval')
JSON = Type('json jsonb')

# Type object for arrays (also equate to their base types):

# ARRAY compares equal to every type name with a leading underscore.
ARRAY = ArrayType()
1073
1074
1075# Mandatory type helpers defined by DB-API 2 specs:
1076
def Date(year, month, day):
    """Return a date object for the given year, month and day."""
    return date(year=year, month=month, day=day)
1080
1081
def Time(hour, minute=0, second=0, microsecond=0):
    """Return a time object for the given time of day."""
    return time(
        hour=hour, minute=minute, second=second, microsecond=microsecond)
1085
1086
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Return a datetime object for the given date and time of day."""
    return datetime(
        year=year, month=month, day=day,
        hour=hour, minute=minute, second=second, microsecond=microsecond)
1090
1091
def DateFromTicks(ticks):
    """Return the local date belonging to the given ticks value."""
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
1095
1096
def TimeFromTicks(ticks):
    """Return the local time of day belonging to the given ticks value."""
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
1100
1101
def TimestampFromTicks(ticks):
    """Return the local time stamp belonging to the given ticks value."""
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
1105
1106
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is a subclass of bytes that adds no attributes or methods of its
    own, so instances behave like ordinary bytes objects but can still be
    told apart from them by their type.
    """
1109
1110
1111# Additional type helpers for PyGreSQL:
1112
class Json:
    """Wrapper marking an object as a value to be passed as JSON."""

    def __init__(self, obj, encode=None):
        """Remember the object and the function used for encoding it.

        The module level JSON encoder is used if encode is not given.
        """
        self.obj = obj
        self.encode = encode if encode else jsonencode

    def __str__(self):
        """Return the JSON text for the wrapped object.

        A wrapped string is assumed to be JSON text already and is
        returned as it is; everything else is run through the encoder.
        """
        value = self.obj
        return value if isinstance(value, basestring) else self.encode(value)

    # the same text is used when the object is adapted as a parameter
    __pg_repr__ = __str__
1127
1128
1129# If run as script, print some information:
1130
if __name__ == '__main__':
    # The name `version` is defined outside this section of the file.
    print('PyGreSQL version', version)
    print('')
    # The module docstring doubles as the usage information.
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.