source: trunk/pgdb.py @ 772

Last change on this file since 772 was 772, checked in by cito, 4 years ago

Improve test coverage for the pgdb module

Includes a simple patch that allows storing Python lists or tuple values
in PostgreSQL array fields (they are not yet converted when read, though).

Also re-activated the shortcut methods on the connection,
since they can sometimes be useful.

Test coverage is now around 95%; the remaining lines are due to support for
old Python versions or obscure database errors that can't easily be triggered.

  • Property svn:keywords set to Id
File size: 31.9 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 772 2016-01-20 22:44:20Z cito $
8#
9
"""pgdb - DB-API 2.0 compliant module for PygreSQL.

(c) 1999, Pascal Andre <andre@via.ecp.fr>.
See package documentation for further information on copyright.

Inline documentation is sparse.
See DB-API 2.0 specification for usage information:
http://www.python.org/peps/pep-0249.html

Basic usage:

    pgdb.connect(connect_string) # open a connection
    # connect_string = 'host:database:user:password:opt'
    # All parts are optional. You may also pass host, database,
    # user and password as keyword arguments. To pass a port,
    # pass it in the host keyword parameter:
    connection = pgdb.connect(host='localhost:5432')

    cursor = connection.cursor() # open a cursor

    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is the same as for the
    # % operator (pyformat paramstyle, e.g. %(name)s or %s). Each
    # parameter value is escaped and quoted according to its Python
    # type before it is inserted into the query.

    cursor.executemany(query, list of params)
    # Execute a query many times, binding each param dictionary
    # or sequence from the list.

    cursor.fetchone() # fetch one row, (value, value, ...) or None

    cursor.fetchall() # fetch all rows, [(value, value, ...), ...]

    cursor.fetchmany([size])
    # returns size or cursor.arraysize number of rows,
    # [(value, value, ...), ...] from result set.
    # Default cursor.arraysize is 1.
    # By default, rows are returned as named tuples.

    cursor.description # returns information about the columns
    #   [(column_name, type_name, display_size,
    #           internal_size, precision, scale, null_ok), ...]
    # Note that display_size, precision, scale and null_ok
    # are not implemented.

    cursor.rowcount # number of rows in the result set of a query,
    # or number of rows affected by a command.
    # Available after a call to execute.

    connection.commit() # commit transaction

    connection.rollback() # or rollback transaction

    cursor.close() # close the cursor

    connection.close() # close the connection
"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75
# Compatibility shims for names that differ between Python 2 and 3.

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# The ABC aliases in the collections module were removed in Python 3.10,
# so prefer collections.abc and fall back for old Python versions.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args, **kwargs):
            # accept keyword arguments as well, so that every attempt to
            # use it reports the missing support instead of a TypeError
            raise NotSupportedError('OrderedDict is not supported')

# pass the decimal type used by this module on to set_decimal()
set_decimal(Decimal)
102
103
104### Module Constants
105
# the module conforms to version 2.0 of the DB API
apilevel = '2.0'

# threads may share the module, but must not share connections
threadsafety = 1

# parameters use Python extended format codes such as %(name)s
paramstyle = 'pyformat'

# The shortcut methods on the connection are not part of DB API 2 and
# are not recommended by the DB SIG, but they can be handy. When this
# is false, the Connection class is defined without them.
shortcutmethods = 1
118
119
120### Internal Types Handling
121
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal values.

    When decimal_type is given, it becomes the module-wide Decimal type:
    it is registered as typecast for 'numeric' columns and passed on
    to set_decimal(). The type that is active afterwards is returned.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    set_decimal(Decimal)
    return Decimal
129
130
131def _cast_bool(value):
132    return value[:1] in ('t', 'T')
133
134
135def _cast_money(value):
136    return Decimal(''.join(filter(
137        lambda v: v in '0123456789.-', value)))
138
139
def _cast_bytea(value):
    """Convert a bytea value from the database using unescape_bytea()."""
    unescaped = unescape_bytea(value)
    return unescaped
142
143
144def _cast_float(value):
145    return float(value)  # this also works with NaN and Infinity
146
147
# typecast functions for database types, keyed by type name;
# values of types missing here are passed through unchanged
_cast = dict(
    bool=_cast_bool, bytea=_cast_bytea,
    int2=int, int4=int, serial=int,
    int8=long, oid=long, oid8=long,
    float4=_cast_float, float8=_cast_float,
    numeric=Decimal, money=_cast_money)
153
154
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls with the given message.

    The sqlstate attribute of the new exception is set to None,
    since the error did not originate from the database server.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
160
161
def _op_error(msg):
    """Create an OperationalError with the given message and no sqlstate."""
    return _db_error(msg, cls=OperationalError)
165
166
class TypeCache(dict):
    """Cache for database types.

    Maps type oids to partial column descriptions. Missing entries are
    looked up in pg_type on first access and remembered afterwards.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        self._src = cnx.source()

    @staticmethod
    def typecast(typ, value):
        """Cast value to database type.

        NULL values (None) and values of types that have no registered
        typecast function are returned unchanged.
        """
        if value is None:
            return None
        cast = _cast.get(typ)
        return value if cast is None else cast(value)

    def getdescr(self, oid):
        """Get the description of the database type with the given oid.

        The result is a 6-tuple (typname, None, typlen, None, None, None).
        The column name is not part of it and must be prepended by the
        caller to get a complete column description.
        """
        if oid in self:
            return self[oid]
        query = ("SELECT typname, typlen "
                 "FROM pg_type WHERE oid=%s" % oid)
        self._src.execute(query)
        row = self._src.fetch(1)[0]
        typname, typlen = row[0], row[1]
        descr = (typname, None, int(typlen), None, None, None)
        self[oid] = descr
        return descr
202
203
204class _quotedict(dict):
205    """Dictionary with auto quoting of its items.
206
207    The quote attribute must be set to the desired quote function.
208    """
209
210    def __getitem__(self, key):
211        return self.quote(super(_quotedict, self).__getitem__(key))
212
213
214### Cursor Object
215
class Cursor(object):
    """Cursor object."""

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self._type_cache = dbcnx._type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns
        self.description = None
        # unofficial attributes for convenience and performance
        self.colnames = self.coltypes = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a custom static row factory is used, never build one
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object."""
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Returns the representation of the value that is inserted into
        the SQL statement. Integers, finite floats and decimals are
        returned unchanged. Raises InterfaceError for unsupported types.
        """
        if isinstance(val, (datetime, date, time, timedelta)):
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            val = "'%s'" % val
        elif isinstance(val, (int, long)):
            pass
        elif isinstance(val, float):
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            elif isnan(val):
                return "'NaN'"
        elif val is None:
            val = 'NULL'
        elif isinstance(val, (list, tuple)):
            if not val:
                # PostgreSQL rejects "ARRAY[]" since it cannot determine
                # the element type; an untyped '{}' literal is accepted
                # where the array type is known from the context
                return "'{}'"
            q = self._quote
            val = 'ARRAY[%s]' % ','.join(str(q(v)) for v in val)
        elif Decimal is not float and isinstance(val, Decimal):
            pass
        elif hasattr(val, '__pg_repr__'):
            val = val.__pg_repr__()
        else:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))
        return val

    def _quoteparams(self, string, parameters):
        """Quote parameters.

        This function works for both mappings and sequences: the values
        of a dict are quoted when they are looked up, the items of any
        other sequence are quoted up front. The quoted parameters are
        then inserted into the string using the % operator.
        """
        if isinstance(parameters, dict):
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def close(self):
        """Close the cursor object."""
        self._src.close()
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command).

        Returns the cursor object, so that calls can be chained.
        """

        # The parameters may also be specified as list of
        # tuples to e.g. insert multiple rows in a single
        # operation, but this kind of usage is deprecated:
        if (parameters and isinstance(parameters, list) and
                isinstance(parameters[0], tuple)):
            return self.executemany(operation, parameters)
        else:
            # not a list of tuples
            return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        A transaction is started first if none is active yet on the
        connection. Returns the cursor object, or None without doing
        anything if seq_of_parameters is empty.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                if parameters:
                    sql = self._quoteparams(operation, parameters)
                else:
                    sql = operation
                rows = self._src.execute(sql)
                # a true result is added to the total row count that is
                # reported below for operations not returning a result set
                if rows:
                    rowcount += rows
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result row count and description
        if self._src.resulttype == RESULT_DQL:
            self.rowcount = self._src.ntuples
            getdescr = self._type_cache.getdescr
            description = [CursorDescription(
                info[1], *getdescr(info[2])) for info in self._src.listinfo()]
            self.colnames = [info[0] for info in description]
            self.coltypes = [info[1] for info in description]
            self.description = description
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set.

        Returns None when no more rows are available.
        """
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        typecast = self._type_cache.typecast
        return [self.row_factory([typecast(typ, value)
            for typ, value in zip(self.coltypes, row)]) for row in result]

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = parameters and len(parameters) or 0
        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
        self.execute(query, parameters)
        return parameters

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.

        The size option sets the size of the buffer used when reading data
        from file-like objects.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.

        Returns the cursor object; its rowcount is set to the number of
        copied rows as reported when the copy operation is finished.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            if size:
                raise ValueError("size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("The input must be %s" % type_name)
                if not binary_format:
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    else:
                        if not stream.endswith(b'\n'):
                            stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "Input stream must consist of %s" % type_name)
                        # rows in text and csv format must end with a
                        # newline, but binary data must be passed unaltered
                        if not binary_format:
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            elif not chunk.endswith(b'\n'):
                                chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("Need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, int):
                raise TypeError("The size option must be an integer")
            if size > 0:

                def chunks():
                    while True:
                        buffer = read(size)
                        # only an empty read marks the end of the stream;
                        # raw streams and pipes may also return short reads
                        if not buffer:
                            break
                        yield buffer

            else:

                def chunks():
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("Must specify a table, not a query")
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if format == 'binary':
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy from")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            if format == 'binary':
                decode = False
            else:
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("decode option must be a boolean")
            if decode and binary_format:
                raise ValueError("decode is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            self.rowcount = 0
            while True:
                row = getdata(decode)
                if isinstance(row, int):
                    # an int marks the end of the data and is taken
                    # as the final row count
                    self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no output stream, return the generator
            return copy()

        # write the rows to the file-like output stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Note that since Python 2.6 the iterator protocol uses __next()__
    # instead of next(), we keep it only for backward compatibility of pgdb.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.
        """
        colnames = self.colnames
        if colnames:
            try:
                try:
                    return namedtuple('Row', colnames, rename=True)._make
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    colnames = [v if v.isalnum() else 'column_%d' % n
                             for n, v in enumerate(colnames)]
                    return namedtuple('Row', colnames)._make
            except ValueError:  # there is still a problem with the field names
                colnames = ['column_%d' % n for n in range(len(colnames))]
                return namedtuple('Row', colnames)._make
724
725
# description of one result column as provided in Cursor.description;
# display_size, precision, scale and null_ok are always None
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
729
730
731### Connection Objects
732
class Connection(object):
    """Connection object."""

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        Raises OperationalError if the given low-level connection
        cannot provide a source object.
        """
        self._cnx = cnx  # connection
        self._tnx = False  # transaction state
        # Check the connection before creating the type cache: TypeCache()
        # calls cnx.source() as well, so a broken connection would otherwise
        # raise the raw error there instead of the OperationalError below.
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("invalid connection")
        self._type_cache = TypeCache(cnx)
        self.cursor_type = Cursor

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is committed if no exception occurred, else rolled back.
        """
        if et is None and ev is None and tb is None:
            self.commit()
        else:
            self.rollback()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back first, ignoring database
        errors. Raises OperationalError if already closed.
        """
        if self._cnx:
            if self._tnx:
                try:
                    self.rollback()
                except DatabaseError:
                    pass
            self._cnx.close()
            self._cnx = None
        else:
            raise _op_error("connection has been closed")

    def commit(self):
        """Commit any pending transaction to the database."""
        if self._cnx:
            if self._tnx:
                # the transaction is considered finished even if COMMIT fails
                self._tnx = False
                try:
                    self._cnx.source().execute("COMMIT")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't commit")
        else:
            raise _op_error("connection has been closed")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        if self._cnx:
            if self._tnx:
                # the transaction is considered finished even if ROLLBACK fails
                self._tnx = False
                try:
                    self._cnx.source().execute("ROLLBACK")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't rollback")
        else:
            raise _op_error("connection has been closed")

    def cursor(self):
        """Return a new cursor object of class cursor_type."""
        if self._cnx:
            try:
                return self.cursor_type(self)
            except Exception:
                raise _op_error("invalid connection")
        else:
            raise _op_error("connection has been closed")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cursor = self.cursor()
            cursor.execute(operation, params)
            return cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cursor = self.cursor()
            cursor.executemany(operation, param_seq)
            return cursor
840
841
842### Module Interface
843
# keep the low-level connect function, since it is shadowed below
_connect_ = connect


def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database and return a Connection object.

    The dsn has the form 'host:database:user:password:opt' where all
    parts are optional. The user, password and database keyword
    arguments override the corresponding dsn parts. The host keyword
    argument overrides the dsn host and may carry a port ('host:port').
    """
    # start with the parts of the DSN, missing parts stay empty
    dsn_values = ["", "", "", "", ""]
    try:
        for index, part in enumerate(dsn.split(":")[:5]):
            dsn_values[index] = part
    except (AttributeError, IndexError, TypeError):
        pass
    dbhost, dbbase, dbuser, dbpasswd, dbopt = dsn_values
    dbport = -1

    # explicit keyword arguments take precedence
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass

    # an empty host means localhost, an empty user the default user
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    return Connection(
        _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd))
891
892
893### Types Handling
894
895class Type(frozenset):
896    """Type class for a couple of PostgreSQL data types.
897
898    PostgreSQL is object-oriented: types are dynamic.
899    We must thus use type names as internal type codes.
900    """
901
902    def __new__(cls, values):
903        if isinstance(values, basestring):
904            values = values.split()
905        return super(Type, cls).__new__(cls, values)
906
907    def __eq__(self, other):
908        if isinstance(other, basestring):
909            return other in self
910        else:
911            return super(Type, self).__eq__(other)
912
913    def __ne__(self, other):
914        if isinstance(other, basestring):
915            return other not in self
916        else:
917            return super(Type, self).__ne__(other)
918
919
920# Mandatory type objects defined by DB-API 2 specs:
921
922STRING = Type('char bpchar name text varchar')
923BINARY = Type('bytea')
924NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
925DATETIME = Type('date time timetz timestamp timestamptz datetime abstime'
926    ' interval tinterval timespan reltime')
927ROWID = Type('oid oid8')
928
929
930# Additional type objects (more specific):
931
932BOOL = Type('bool')
933SMALLINT = Type('int2')
934INTEGER = Type('int2 int4 int8 serial')
935LONG = Type('int8')
936FLOAT = Type('float4 float8')
937NUMERIC = Type('numeric')
938MONEY = Type('money')
939DATE = Type('date')
940TIME = Type('time timetz')
941TIMESTAMP = Type('timestamp timestamptz datetime abstime')
942INTERVAL = Type('interval tinterval timespan reltime')
943
944
945# Mandatory type helpers defined by DB-API 2 specs:
946
947def Date(year, month, day):
948    """Construct an object holding a date value."""
949    return date(year, month, day)
950
951
952def Time(hour, minute=0, second=0, microsecond=0):
953    """Construct an object holding a time value."""
954    return time(hour, minute, second, microsecond)
955
956
957def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
958    """construct an object holding a time stamp valu."""
959    return datetime(year, month, day, hour, minute, second, microsecond)
960
961
962def DateFromTicks(ticks):
963    """Construct an object holding a date value from the given ticks value."""
964    return Date(*localtime(ticks)[:3])
965
966
967def TimeFromTicks(ticks):
968    """construct an object holding a time value from the given ticks value."""
969    return Time(*localtime(ticks)[3:6])
970
971
972def TimestampFromTicks(ticks):
973    """construct an object holding a time stamp from the given ticks value."""
974    return Timestamp(*localtime(ticks)[:6])
975
976
977class Binary(bytes):
978    """construct an object capable of holding a binary (long) string value."""
979
980
981# If run as script, print some information:
982
983if __name__ == '__main__':
984    print('PyGreSQL version', version)
985    print('')
986    print(__doc__)
Note: See TracBrowser for help on using the repository browser.