source: trunk/pgdb.py @ 740

Last change on this file since 740 was 740, checked in by cito, 4 years ago

Reformat some error messages and docstrings

Try to achieve a somewhat consistent style of docstrings and error
messages in the trunk. The docstrings use PEP 257, with a slight
variation between C code and Python code. The error messages are
capitalized and do not end with a period. (I prefer the periods,
but most Python code I have seen doesn't use them.)

(I know, a foolish consistency is the hobgoblin of little minds.)

  • Property svn:keywords set to Id
File size: 32.0 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 740 2016-01-14 02:35:55Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is the same as for the
    # % operator, and the parameter values are quoted automatically.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
    cursor.fetchone() # fetch one row, or None if no rows are left

    cursor.fetchall() # fetch all (remaining) rows as a list of rows

    cursor.fetchmany([size])
    # returns a list of the next size or cursor.arraysize rows
    # from the result set. Default cursor.arraysize is 1.
    # By default, each row is returned as a named tuple.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
    cursor.rowcount # number of rows in the result set, or number of
    # rows affected by a DML statement; -1 if not available.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75
# Compatibility shims for names that differ between Python 2 and 3.

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

try:
    # The abstract base classes live in collections.abc since Python 3.3;
    # the old aliases in collections were removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args):
            """Placeholder that fails when an OrderedDict is requested."""
            raise NotSupportedError('OrderedDict is not supported')
# Register the decimal type in use with the underlying C module
# (set_decimal presumably comes from the _pg star import).
set_decimal(Decimal)
102
103
104### Module Constants
105
# The module complies with the DB API 2.0 specification.
apilevel = '2.0'

# Threads may share the module, but not connections.
threadsafety = 1

# Parameters use the extended Python format codes, e.g. %(name)s.
paramstyle = 'pyformat'

# The shortcut methods (execute and executemany on connection objects)
# are not supported by default, since they have been excluded from
# DB API 2 and are not recommended by the DB SIG.

shortcutmethods = 0
120
121
122### Internal Types Handling
123
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal values.

    When decimal_type is given, it becomes the cast for numeric columns,
    the new global Decimal type, and is passed on to set_decimal().
    The type in use after the call is returned in either case.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    set_decimal(Decimal)
    return Decimal
131
132
133def _cast_bool(value):
134    return value[:1] in ('t', 'T')
135
136
137def _cast_money(value):
138    return Decimal(''.join(filter(
139        lambda v: v in '0123456789.-', value)))
140
141
def _cast_bytea(value):
    """Cast a bytea value by passing it through unescape_bytea()."""
    unescaped = unescape_bytea(value)
    return unescaped
144
145
146def _cast_float(value):
147    try:
148        return float(value)
149    except ValueError:
150        if value == 'NaN':
151            return nan
152        elif value == 'Infinity':
153            return inf
154        elif value == '-Infinity':
155            return -inf
156        raise
157
158
# Map database type names to the functions used by TypeCache.typecast()
# for casting fetched column values; other types are passed unchanged.
_cast = {
    'bool': _cast_bool,
    'bytea': _cast_bytea,
    'int2': int,
    'int4': int,
    'serial': int,
    'int8': long,
    'oid': long,
    'oid8': long,
    'float4': _cast_float,
    'float8': _cast_float,
    'numeric': Decimal,
    'money': _cast_money,
}
164
165
def _db_error(msg, cls=DatabaseError):
    """Return (not raise) an error of class cls with sqlstate set to None."""
    err = cls(msg)
    err.sqlstate = None
    return err
171
172
def _op_error(msg):
    """Return (not raise) an OperationalError with an empty sqlstate."""
    return _db_error(msg, cls=OperationalError)
176
177
class TypeCache(dict):
    """Cache for database types.

    Maps type oids to column description tuples (without the column
    name) which are looked up in the pg_type catalog on first use.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        # separate source object used only for the catalog queries
        self._src = cnx.source()

    @staticmethod
    def typecast(typ, value):
        """Cast value to database type."""
        if value is None:
            # NULL values never need a typecast
            return None
        cast = _cast.get(typ)
        # without a registered cast the value is returned as it is
        return value if cast is None else cast(value)

    def getdescr(self, oid):
        """Get the description of the database type with the given oid.

        Returns a tuple (typname, None, typlen, None, None, None); the
        column name must be prepended by the caller.
        """
        if oid in self:
            return self[oid]
        self._src.execute(
            "SELECT typname, typlen FROM pg_type WHERE oid=%s" % oid)
        row = self._src.fetch(1)[0]
        descr = (row[0], None, int(row[1]), None, None, None)
        self[oid] = descr
        return descr
213
214
215class _quotedict(dict):
216    """Dictionary with auto quoting of its items.
217
218    The quote attribute must be set to the desired quote function.
219    """
220
221    def __getitem__(self, key):
222        return self.quote(super(_quotedict, self).__getitem__(key))
223
224
225### Cursor Object
226
class Cursor(object):
    """Cursor object.

    Fetched values are cast via the type cache of the connection and
    each row is passed through row_factory before it is returned.
    """

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self._type_cache = dbcnx._type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns
        self.description = None
        # unofficial attributes for convenience and performance
        self.colnames = self.coltypes = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            # (it is rebuilt by build_row_factory() after each query)
            self.row_factory = None
        else:
            # a subclass provides its own static row factory,
            # which must not be replaced after queries
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object.

        The cursor is closed; exceptions are not suppressed.
        """
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Strings and date/time values are escaped and put in single quotes,
        values of the Binary type are escaped as bytea, None becomes NULL,
        infinite and NaN floats become quoted special literals, and lists
        or tuples become a parenthesized list of quoted items. Other ints,
        floats and Decimals are returned unchanged; objects with a
        __pg_repr__() method are quoted by calling it.

        Raises InterfaceError for values of unsupported types.
        """
        if isinstance(val, (datetime, date, time, timedelta)):
            # date/time values are passed as quoted strings
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            val = "'%s'" % val
        elif isinstance(val, (int, long)):
            pass
        elif isinstance(val, float):
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            elif isnan(val):
                return "'NaN'"
        elif val is None:
            val = 'NULL'
        elif isinstance(val, (list, tuple)):
            val = '(%s)' % ','.join(map(lambda v: str(self._quote(v)), val))
        elif Decimal is not float and isinstance(val, Decimal):
            # the global Decimal type may have been replaced via
            # decimal_type(); if it is float, floats are handled above
            pass
        elif hasattr(val, '__pg_repr__'):
            val = val.__pg_repr__()
        else:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))
        return val

    def _quoteparams(self, string, parameters):
        """Quote parameters.

        This function works for both mappings and sequences.
        """
        if isinstance(parameters, dict):
            # values are quoted lazily when the % operator looks them up
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def close(self):
        """Close the cursor object."""
        self._src.close()
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command)."""

        # The parameters may also be specified as list of
        # tuples to e.g. insert multiple rows in a single
        # operation, but this kind of usage is deprecated:
        if (parameters and isinstance(parameters, list) and
                isinstance(parameters[0], tuple)):
            return self.executemany(operation, parameters)
        else:
            # not a list of tuples
            return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        A transaction is started first if none is active on the connection.
        Returns the cursor, or None if the parameter sequence is empty
        (in which case nothing is executed).
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        # make sure sql is always defined for the error messages below
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                if parameters:
                    sql = self._quoteparams(operation, parameters)
                else:
                    sql = operation
                rows = self._src.execute(sql)
                # NOTE(review): rows is presumably the number of affected
                # rows for DML statements and empty otherwise - confirm
                if rows:
                    rowcount += rows
                else:
                    self.rowcount = -1
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error("error in '%s': '%s' " % (sql, err))
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result raw count and description
        if self._src.resulttype == RESULT_DQL:
            self.rowcount = self._src.ntuples
            getdescr = self._type_cache.getdescr
            description = [CursorDescription(
                info[1], *getdescr(info[2])) for info in self._src.listinfo()]
            self.colnames = [info[0] for info in description]
            self.coltypes = [info[1] for info in description]
            self.description = description
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set.

        Returns None when no more rows are available.
        """
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.

        Each value is cast according to its column type, and each
        row is passed through row_factory.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        typecast = self._type_cache.typecast
        return [self.row_factory([typecast(typ, value)
            for typ, value in zip(self.coltypes, row)]) for row in result]

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = parameters and len(parameters) or 0
        # the procedure name is put in double quotes, but not escaped
        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
        self.execute(query, parameters)
        return parameters

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.

        The size option sets the size of the buffer used when reading data
        from file-like objects.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.

        Raises TypeError or ValueError for invalid arguments.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            if size:
                raise ValueError("size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("The input must be %s" % type_name)
                if not binary_format:
                    # text input must be terminated by a newline
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    else:
                        if not stream.endswith(b'\n'):
                            stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "Input stream must consist of %s" % type_name)
                        # Rows in the text formats must be terminated by a
                        # newline, but binary data must be passed unchanged,
                        # since an added newline would corrupt it.
                        if not binary_format:
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            else:
                                if not chunk.endswith(b'\n'):
                                    chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("Need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, (int, long)):
                # check the type before comparing, so that the intended
                # error is raised instead of a failed comparison
                raise TypeError("The size option must be an integer")
            if size > 0:

                def chunks():
                    while True:
                        buffer = read(size)
                        yield buffer
                        # a short read signals the end of the stream
                        if not buffer or len(buffer) < size:
                            break

            else:

                def chunks():
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("Must specify a table, not a query")
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if format == 'binary':
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy to")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            if format == 'binary':
                decode = False
            else:
                # by default, decode only where str is unicode (Python 3)
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("decode option must be a boolean")
            if decode and binary_format:
                raise ValueError("decode is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            self.rowcount = 0
            while True:
                row = getdata(decode)
                # an int instead of a row signals the end of the data
                if isinstance(row, int):
                    if self.rowcount != row:
                        self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no input stream, return the generator
            return copy()

        # write the rows to the file-like input stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Python 2 calls next() instead of __next__() in the iteration
    # protocol, so the method is provided under both names.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.

        Returns None if there are no column names.
        """
        colnames = self.colnames
        if colnames:
            try:
                try:
                    return namedtuple('Row', colnames, rename=True)._make
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    colnames = [v if v.isalnum() else 'column_%d' % n
                             for n, v in enumerate(colnames)]
                    return namedtuple('Row', colnames)._make
            except ValueError:  # there is still a problem with the field names
                colnames = ['column_%d' % n for n in range(len(colnames))]
                return namedtuple('Row', colnames)._make
733
734
# Named tuple type for the items of the DB API 2 cursor.description.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
738
739
740### Connection Objects
741
class Connection(object):
    """Connection object."""

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        Raises OperationalError if no source object can be obtained
        from the given low-level connection.
        """
        self._cnx = cnx  # connection
        self._tnx = False  # transaction state
        self.cursor_type = Cursor
        # Check the connection before creating the type cache: the type
        # cache also calls cnx.source(), and a failure there would
        # otherwise escape as an unconverted low-level error.
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("invalid connection")
        self._type_cache = TypeCache(cnx)

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is committed on success and rolled back if an exception
        occurred. Exceptions are not suppressed.
        """
        if et is None and ev is None and tb is None:
            self.commit()
        else:
            self.rollback()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back first; database errors
        during this rollback are ignored. Raises OperationalError if
        the connection has already been closed.
        """
        if self._cnx:
            if self._tnx:
                try:
                    self.rollback()
                except DatabaseError:
                    pass
            self._cnx.close()
            self._cnx = None
        else:
            raise _op_error("connection has been closed")

    def commit(self):
        """Commit any pending transaction to the database.

        The transaction state is reset before COMMIT is sent, so it is
        also reset if the COMMIT fails.
        """
        if self._cnx:
            if self._tnx:
                self._tnx = False
                try:
                    self._cnx.source().execute("COMMIT")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't commit")
        else:
            raise _op_error("connection has been closed")

    def rollback(self):
        """Roll back to the start of any pending transaction.

        The transaction state is reset before ROLLBACK is sent.
        """
        if self._cnx:
            if self._tnx:
                self._tnx = False
                try:
                    self._cnx.source().execute("ROLLBACK")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't rollback")
        else:
            raise _op_error("connection has been closed")

    def cursor(self):
        """Return a new cursor object using the connection.

        The cursor is an instance of the cursor_type attribute.
        """
        if self._cnx:
            try:
                return self.cursor_type(self)
            except Exception:
                raise _op_error("invalid connection")
        else:
            raise _op_error("connection has been closed")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cursor = self.cursor()
            cursor.execute(operation, params)
            return cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cursor = self.cursor()
            cursor.executemany(operation, param_seq)
            return cursor
849
850
851### Module Interface
852
_connect_ = connect

def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database.

    The dsn has the form 'host:database:user:password:opt', where all
    parts are optional. The keyword arguments override the corresponding
    dsn parts; the host keyword may also contain a port as 'host:port'.
    """
    # split the DSN into its five parts, missing parts are empty
    try:
        fields = dsn.split(":")
    except (AttributeError, IndexError, TypeError):
        fields = []
    dbhost, dbbase, dbuser, dbpasswd, dbopt = (list(fields) + [""] * 5)[:5]
    dbport = -1  # passed on when no port is given

    # explicit keyword arguments take precedence over the DSN
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            host_fields = host.split(":")
            dbhost = host_fields[0]
            dbport = int(host_fields[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass

    # empty host is localhost, empty user is the default user
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    # open the low-level connection and wrap it
    return Connection(
        _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd))
900
901
902### Types Handling
903
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozenset of PostgreSQL type names.  It compares equal to
    a single type name string if that name is one of its members, and
    like an ordinary frozenset to anything else.
    """

    def __new__(cls, values):
        """Create a type object from a set of type names.

        The values may be given as one string of whitespace-separated
        type names or as an iterable of type names.
        """
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        """Check membership for type names, set equality otherwise."""
        if isinstance(other, basestring):
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        """Check non-membership for type names, set inequality otherwise."""
        if isinstance(other, basestring):
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Defining __eq__ without __hash__ makes Python 3 set __hash__ to None,
    # which would make type objects unhashable (unusable in sets or as dict
    # keys).  Keep the hash of the underlying frozenset instead.
    __hash__ = frozenset.__hash__
927
928
# Mandatory type objects defined by DB-API 2 specs
# (each one lists the PostgreSQL type names it matches):

STRING = Type(('char', 'bpchar', 'name', 'text', 'varchar'))
BINARY = Type(('bytea',))
NUMBER = Type(('int2', 'int4', 'serial', 'int8',
    'float4', 'float8', 'numeric', 'money'))
DATETIME = Type(('date', 'time', 'timetz',
    'timestamp', 'timestamptz', 'datetime', 'abstime',
    'interval', 'tinterval', 'timespan', 'reltime'))
ROWID = Type(('oid', 'oid8'))
937
938
# Additional, more specific type objects
# (each one lists the PostgreSQL type names it matches):

BOOL = Type(('bool',))
SMALLINT = Type(('int2',))
INTEGER = Type(('int2', 'int4', 'int8', 'serial'))
LONG = Type(('int8',))
FLOAT = Type(('float4', 'float8'))
NUMERIC = Type(('numeric',))
MONEY = Type(('money',))
DATE = Type(('date',))
TIME = Type(('time', 'timetz'))
TIMESTAMP = Type(('timestamp', 'timestamptz', 'datetime', 'abstime'))
INTERVAL = Type(('interval', 'tinterval', 'timespan', 'reltime'))
952
953
954# Mandatory type helpers defined by DB-API 2 specs:
955
def Date(year, month, day):
    """Construct an object holding a date value.

    The result is a date object for the given year, month and day.
    """
    return date(year=year, month=month, day=day)
959
960
def Time(hour, minute=0, second=0, microsecond=0):
    """Construct an object holding a time value.

    Omitted minute, second and microsecond components default to zero.
    """
    return time(hour=hour, minute=minute, second=second,
        microsecond=microsecond)
964
965
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Construct an object holding a time stamp value.

    Omitted time-of-day components default to zero (midnight).
    """
    return datetime(year, month, day, hour, minute, second, microsecond)
969
970
def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks (seconds since the epoch) are interpreted in local time.
    """
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
974
975
def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    The ticks (seconds since the epoch) are interpreted in local time;
    only the hour, minute and second are kept.
    """
    return Time(*localtime(ticks)[3:6])
979
980
def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    The ticks (seconds since the epoch) are interpreted in local time;
    fractions of a second are dropped.
    """
    return Timestamp(*localtime(ticks)[:6])
984
985
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value."""
988
989
# If run as script, print some information:

if __name__ == '__main__':
    # Print the version string, a blank line and then the module docstring.
    # `version` is a module-level name that is not defined in this part of
    # the file.
    # NOTE(review): on Python 2 these calls rely on print being a function
    # (from __future__ import print_function); otherwise the first call
    # prints a tuple -- confirm that import exists at the top of the file.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.