source: trunk/pgdb.py @ 773

Last change on this file since 773 was 773, checked in by cito, 4 years ago

Do not call set_decimal in pgdb

The set_decimal() function is only relevant for the pg module (its getresult() method).
Don't call it in the pgdb module; the two modules should not interfere with each other.

  • Property svn:keywords set to Id
File size: 31.8 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 773 2016-01-21 13:04:52Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is the same as the %
    # operator; the parameter values are quoted and escaped first.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
    cursor.fetchone() # fetch one row (a named tuple by default), or None

    cursor.fetchall() # fetch all (remaining) rows, [row, row, ...]

    cursor.fetchmany([size])
    # returns size or cursor.arraysize number of rows,
    # [row, row, ...] from result set.
    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
from __future__ import print_function

from _pg import *

from datetime import date, time, datetime, timedelta
from time import localtime
from decimal import Decimal
from math import isnan, isinf
from collections import namedtuple

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

try:  # the ABCs live in collections.abc since Python 3.3
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args):
            raise NotSupportedError('OrderedDict is not supported')
100
101
102### Module Constants
103
# DB API level implemented by this module
apilevel = '2.0'

# threads may share the module, but they may not share connections
threadsafety = 1

# query parameters use Python's extended format codes, e.g. %(name)s
paramstyle = 'pyformat'

# The shortcut methods (execute/executemany directly on a connection) are
# not part of DB API 2 and are not recommended by the DB SIG, but they
# can be handy; the Connection class only defines them if this is true.
shortcutmethods = 1
116
117
118### Internal Types Handling
119
def decimal_type(decimal_type=None):
    """Return the type used for decimal values, optionally replacing it.

    When a type is passed (e.g. float), it is stored as the 'numeric'
    entry of the _cast table and rebound to the module-level Decimal
    name, which is also used for money values and for quoting. Without
    an argument the current type is returned unchanged.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    # same order as a chained assignment: the cast table first
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    return Decimal
126
127
128def _cast_bool(value):
129    return value[:1] in ('t', 'T')
130
131
132def _cast_money(value):
133    return Decimal(''.join(filter(
134        lambda v: v in '0123456789.-', value)))
135
136
def _cast_bytea(value):
    """Convert an escaped bytea column value via unescape_bytea()."""
    unescaped = unescape_bytea(value)
    return unescaped
139
140
141def _cast_float(value):
142    return float(value)  # this also works with NaN and Infinity
143
144
# Maps PostgreSQL type names to the function that converts a fetched
# column value of that type; types not listed are returned unchanged.
_cast = dict(
    bool=_cast_bool, bytea=_cast_bytea,
    int2=int, int4=int, serial=int,
    int8=long, oid=long, oid8=long,
    float4=_cast_float, float8=_cast_float,
    numeric=Decimal, money=_cast_money)
150
151
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls (DatabaseError by default).

    The sqlstate attribute of the new exception is set to None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
157
158
def _op_error(msg):
    """Create an OperationalError with sqlstate set to None."""
    return _db_error(msg, cls=OperationalError)
162
163
class TypeCache(dict):
    """Cache mapping type oids to partial column descriptions.

    Descriptions are looked up in pg_type on first use and then stored
    in the dictionary itself.
    """

    def __init__(self, cnx):
        """Set up an empty cache that queries through a source of cnx."""
        super(TypeCache, self).__init__()
        self._src = cnx.source()

    @staticmethod
    def typecast(typ, value):
        """Convert value with the cast function registered for typ.

        None (SQL NULL) and values of types without a registered cast
        function in _cast are returned unchanged.
        """
        if value is not None:
            cast = _cast.get(typ)
            if cast is not None:
                return cast(value)
        return value

    def getdescr(self, oid):
        """Return the description tuple of the database type with oid.

        The result is (typname, None, typlen, None, None, None); the
        column name is not included and must be put in front by the
        caller.
        """
        descr = self.get(oid)  # cached entries are tuples, never None
        if descr is None:
            self._src.execute(
                "SELECT typname, typlen "
                "FROM pg_type WHERE oid=%s" % oid)
            row = self._src.fetch(1)[0]
            typname, typlen = row[0], row[1]
            descr = self[oid] = (typname, None, int(typlen), None, None, None)
        return descr
199
200
201class _quotedict(dict):
202    """Dictionary with auto quoting of its items.
203
204    The quote attribute must be set to the desired quote function.
205    """
206
207    def __getitem__(self, key):
208        return self.quote(super(_quotedict, self).__getitem__(key))
209
210
211### Cursor Object
212
class Cursor(object):
    """Cursor object.

    A cursor runs operations through its own source object obtained from
    the connection, converts fetched values with the connection's type
    cache and builds result rows with a row factory (named tuples unless
    a subclass provides its own row factory).
    """

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self._type_cache = dbcnx._type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns
        self.description = None
        # unofficial attributes for convenience and performance
        self.colnames = self.coltypes = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a subclass supplied a static row factory, so never
            # replace it with a dynamically built one
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object by closing it."""
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Date/time values are converted to strings first. Strings are
        escaped (Binary values with escape_bytea, others with escape_string)
        and put in single quotes. Infinite and NaN floats become quoted
        special values, None becomes NULL, lists and tuples become ARRAY
        constructors, and objects with a __pg_repr__ method are rendered
        by it. Integers, finite floats and decimals are returned as they
        are, to be formatted by the % operator.

        Raises InterfaceError for values of any other type.
        """
        if isinstance(val, (datetime, date, time, timedelta)):
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            val = "'%s'" % val
        elif isinstance(val, (int, long)):
            pass
        elif isinstance(val, float):
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            elif isnan(val):
                return "'NaN'"
        elif val is None:
            val = 'NULL'
        elif isinstance(val, (list, tuple)):
            q = self._quote
            val = 'ARRAY[%s]' % ','.join(str(q(v)) for v in val)
        elif Decimal is not float and isinstance(val, Decimal):
            pass
        elif hasattr(val, '__pg_repr__'):
            val = val.__pg_repr__()
        else:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))
        return val

    def _quoteparams(self, string, parameters):
        """Quote parameters.

        This function works for both mappings and sequences: every value
        is passed through _quote() before being formatted into string
        with the % operator.
        """
        if isinstance(parameters, dict):
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def close(self):
        """Close the cursor object and reset its result attributes."""
        self._src.close()
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command)."""

        # The parameters may also be specified as list of
        # tuples to e.g. insert multiple rows in a single
        # operation, but this kind of usage is deprecated:
        if (parameters and isinstance(parameters, list) and
                isinstance(parameters[0], tuple)):
            return self.executemany(operation, parameters)
        else:
            # not a list of tuples
            return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        A transaction is begun first if none is open on the connection.
        Each parameter set is quoted into the operation and executed.
        For a query result, description, colnames, coltypes and rowcount
        are set afterwards; otherwise rowcount and lastrowid are set.

        Returns the cursor itself, or None if seq_of_parameters is empty.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                if parameters:
                    sql = self._quoteparams(operation, parameters)
                else:
                    sql = operation
                rows = self._src.execute(sql)
                # NOTE(review): rows is accumulated as an affected-row
                # count when truthy; confirm the return values of
                # source.execute() in _pg for the different statement types
                if rows:
                    rowcount += rows
                else:
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result raw count and description
        if self._src.resulttype == RESULT_DQL:
            self.rowcount = self._src.ntuples
            getdescr = self._type_cache.getdescr
            description = [CursorDescription(
                info[1], *getdescr(info[2])) for info in self._src.listinfo()]
            self.colnames = [info[0] for info in description]
            self.coltypes = [info[1] for info in description]
            self.description = description
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set, or None if exhausted."""
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.

        Each value is converted with the type cache; each row is then
        passed to the row factory. If there is no row factory (e.g. when
        build_row_factory() found no column names), rows are returned as
        plain lists.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        typecast = self._type_cache.typecast
        coltypes = self.coltypes
        # calling a missing (None) row factory would raise a TypeError
        make_row = self.row_factory or list
        return [make_row([typecast(typ, value)
            for typ, value in zip(coltypes, row)]) for row in result]

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = len(parameters) if parameters else 0
        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
        self.execute(query, parameters)
        return parameters

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.

        The size option sets the size of the buffer used when reading data
        from file-like objects.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.

        In text and csv format, string input is terminated with a newline
        if necessary; binary input is passed on unchanged.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            if size:
                raise ValueError("size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("The input must be %s" % type_name)
                if not binary_format:
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    else:
                        if not stream.endswith(b'\n'):
                            stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "Input stream must consist of %s" % type_name)
                        # Only text and csv rows need a line terminator;
                        # appending one to binary COPY data would corrupt
                        # it (the single string case above skips it too).
                        if not binary_format:
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            else:
                                if not chunk.endswith(b'\n'):
                                    chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("Need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, int):
                raise TypeError("The size option must be an integer")
            if size > 0:

                def chunks():
                    while True:
                        buffer = read(size)
                        yield buffer
                        if not buffer or len(buffer) < size:
                            break

            else:

                def chunks():
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("Must specify a table, not a query")
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if format == 'binary':
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy to")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            if format == 'binary':
                decode = False
            else:
                # decode by default only where str is the unicode type
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("decode option must be a boolean")
            if decode and binary_format:
                raise ValueError("decode is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            self.rowcount = 0
            while True:
                row = getdata(decode)
                # an int return value signals the end of the data
                if isinstance(row, int):
                    if self.rowcount != row:
                        self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no input stream, return the generator
            return copy()

        # write the rows to the file-like input stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Note that since Python 2.6 the iterator protocol uses __next()__
    # instead of next(), we keep it only for backward compatibility of pgdb.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.
        Returns None if there are no column names.
        """
        colnames = self.colnames
        if colnames:
            try:
                try:
                    return namedtuple('Row', colnames, rename=True)._make
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    colnames = [v if v.isalnum() else 'column_%d' % n
                             for n, v in enumerate(colnames)]
                    return namedtuple('Row', colnames)._make
            except ValueError:  # there is still a problem with the field names
                colnames = ['column_%d' % n for n in range(len(colnames))]
                return namedtuple('Row', colnames)._make
721
722
# One entry of Cursor.description, with the seven items required by DB-API 2.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
726
727
728### Connection Objects
729
class Connection(object):
    """Connection object.

    Wraps a low-level connection, keeps track of whether a transaction
    has been begun and creates cursors of the class set in cursor_type.
    """

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        Raises OperationalError("invalid connection") if no source object
        can be obtained from the low-level connection cnx.
        """
        self._cnx = cnx  # connection
        self._tnx = False  # transaction state
        self.cursor_type = Cursor
        try:
            # creating the type cache opens a source object on cnx, which
            # is also the check that the connection is usable
            self._type_cache = TypeCache(cnx)
        except Exception:
            raise _op_error("invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is committed if no exception occurred, rolled back otherwise.
        """
        if et is None and ev is None and tb is None:
            self.commit()
        else:
            self.rollback()

    def close(self):
        """Close the connection object.

        An open transaction is rolled back first; database errors during
        that rollback are ignored. Raises OperationalError if the
        connection has already been closed.
        """
        if self._cnx:
            if self._tnx:
                try:
                    self.rollback()
                except DatabaseError:
                    pass
            self._cnx.close()
            self._cnx = None
        else:
            raise _op_error("connection has been closed")

    def commit(self):
        """Commit any pending transaction to the database.

        Does nothing if no transaction is open. Raises OperationalError
        if the connection has been closed.
        """
        if self._cnx:
            if self._tnx:
                self._tnx = False
                try:
                    self._cnx.source().execute("COMMIT")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't commit")
        else:
            raise _op_error("connection has been closed")

    def rollback(self):
        """Roll back to the start of any pending transaction.

        Does nothing if no transaction is open. Raises OperationalError
        if the connection has been closed.
        """
        if self._cnx:
            if self._tnx:
                self._tnx = False
                try:
                    self._cnx.source().execute("ROLLBACK")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't rollback")
        else:
            raise _op_error("connection has been closed")

    def cursor(self):
        """Return a new cursor object (of class cursor_type) using the connection."""
        if self._cnx:
            try:
                return self.cursor_type(self)
            except Exception:
                raise _op_error("invalid connection")
        else:
            raise _op_error("connection has been closed")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cursor = self.cursor()
            cursor.execute(operation, params)
            return cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cursor = self.cursor()
            cursor.executemany(operation, param_seq)
            return cursor
837
838
839### Module Interface
840
_connect_ = connect


def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database and return a Connection object.

    The dsn string has the form 'host:database:user:password:opt'; every
    part is optional and missing parts stay empty. The user, password,
    database and host keyword arguments override the respective dsn parts
    when given. A port can be appended to host as 'host:port'; if it is
    missing or not an integer, the port stays -1. Empty host and user
    values are passed on as None.
    """
    # dsn parts in order: host, database, user, password, options
    parts = ["", "", "", "", ""]
    try:
        for index, part in enumerate(dsn.split(":")[:5]):
            parts[index] = part
    except (AttributeError, IndexError, TypeError):
        pass  # no usable dsn, keep the empty defaults
    dbhost, dbbase, dbuser, dbpasswd, dbopt = parts

    # explicit keyword arguments take precedence over the dsn
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    dbport = -1
    if host is not None:
        try:
            hostinfo = host.split(":")
            dbhost = hostinfo[0]
            dbport = int(hostinfo[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # keep whatever could be parsed so far

    # empty host and user are passed as None
    # (presumably selecting the client library defaults)
    if dbhost == "":
        dbhost = None
    if dbuser == "":
        dbuser = None

    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
888
889
890### Types Handling
891
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozenset of type names that also compares equal to any
    single type name string it contains, e.g. STRING == 'text' is true.
    """

    def __new__(cls, values):
        """Create a Type from an iterable or a space separated string."""
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Overriding __eq__ implicitly sets __hash__ to None in Python 3,
    # which would make type objects unusable in sets and as dict keys.
    __hash__ = frozenset.__hash__
915
916
# Type objects required by the DB-API 2 specification; each one compares
# equal to every PostgreSQL type name it lists.

STRING = Type(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = Type(['bytea'])
NUMBER = Type(['int2', 'int4', 'serial', 'int8',
    'float4', 'float8', 'numeric', 'money'])
DATETIME = Type(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
    'datetime', 'abstime', 'interval', 'tinterval', 'timespan', 'reltime'])
ROWID = Type(['oid', 'oid8'])


# Additional, more specific type objects:

BOOL = Type(['bool'])
SMALLINT = Type(['int2'])
INTEGER = Type(['int2', 'int4', 'int8', 'serial'])
LONG = Type(['int8'])
FLOAT = Type(['float4', 'float8'])
NUMERIC = Type(['numeric'])
MONEY = Type(['money'])
DATE = Type(['date'])
TIME = Type(['time', 'timetz'])
TIMESTAMP = Type(['timestamp', 'timestamptz', 'datetime', 'abstime'])
INTERVAL = Type(['interval', 'tinterval', 'timespan', 'reltime'])
940
941
942# Mandatory type helpers defined by DB-API 2 specs:
943
def Date(year, month, day):
    """Return a datetime.date for the given year, month and day."""
    return date(year=year, month=month, day=day)
947
948
def Time(hour, minute=0, second=0, microsecond=0):
    """Return a datetime.time; omitted components default to zero."""
    return time(hour=hour, minute=minute, second=second,
        microsecond=microsecond)
952
953
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Construct an object holding a time stamp value (a datetime)."""
    return datetime(year=year, month=month, day=day, hour=hour,
        minute=minute, second=second, microsecond=microsecond)
957
958
def DateFromTicks(ticks):
    """Return the local date for the given ticks (seconds since the epoch)."""
    year, month, day = localtime(ticks)[:3]
    return date(year, month, day)
962
963
def TimeFromTicks(ticks):
    """Return the local time of day for the given ticks (whole seconds)."""
    hour, minute, second = localtime(ticks)[3:6]
    return time(hour, minute, second, 0)
967
968
def TimestampFromTicks(ticks):
    """Return the local datetime for the given ticks (whole seconds)."""
    fields = localtime(ticks)[:6]
    return datetime(*fields)
972
973
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    When used as a query parameter, a Binary value is escaped with
    escape_bytea() instead of escape_string().
    """
976
977
# When run as a script, show the PyGreSQL version and the usage notes:

if __name__ == '__main__':
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.