source: trunk/module/pgdb.py @ 697

Last change on this file since 697 was 697, checked in by cito, 3 years ago

Test rowcount attribute with copy methods

  • Property svn:keywords set to Id
File size: 32.0 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 697 2016-01-07 16:59:10Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
23    # All parts are optional. You may also pass host through
24    # password as keyword arguments. To pass a port,
25    # pass it in the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is the same as for the
    # % operator; the parameter values are quoted automatically.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
    cursor.fetchone() # fetch one row, (value, value, ...)
    # By default, rows are returned as named tuples.

    cursor.fetchall() # fetch all rows, [(value, value, ...), ...]

    cursor.fetchmany([size])
    # returns size or cursor.arraysize number of rows,
    # [(value, value, ...), ...] from result set.
    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64
65"""
66
67from __future__ import print_function
68
69from _pg import *
70
71from datetime import date, time, datetime, timedelta
72from time import localtime
73from decimal import Decimal
74from math import isnan, isinf
75from collections import namedtuple
76
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

try:
    # The abstract base classes live in collections.abc since Python 3.3,
    # and their aliases in collections have been removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args, **kwargs):
            """Placeholder signalling that OrderedDict is not available."""
            raise NotSupportedError('OrderedDict is not supported')
101
# register Decimal as the type used for decimal values (cf. decimal_type())
set_decimal(Decimal)


### Module Constants

# compliant with DB API 2.0
apilevel = '2.0'

# threads may share the module, but not connections
threadsafety = 1

# this module uses extended Python format codes, e.g. %(name)s
paramstyle = 'pyformat'

# The shortcut methods execute() and executemany() of connection objects
# are not supported by default, since they have been excluded from
# DB API 2 and are not recommended by the DB SIG.
shortcutmethods = 0
121
122
123### Internal Types Handling
124
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal values.

    When a type is passed, it replaces the type used for casting values
    of the 'numeric' database type and is handed on to set_decimal().
    In any case, the type that is currently in use is returned.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    set_decimal(decimal_type)
    return Decimal
132
133
134def _cast_bool(value):
135    return value[:1] in ('t', 'T')
136
137
138def _cast_money(value):
139    return Decimal(''.join(filter(
140        lambda v: v in '0123456789.-', value)))
141
142
def _cast_bytea(value):
    """Cast a bytea value received from the database.

    The escaped representation is decoded with unescape_bytea().
    """
    unescaped = unescape_bytea(value)
    return unescaped
145
146
147def _cast_float(value):
148    try:
149        return float(value)
150    except ValueError:
151        if value == 'NaN':
152            return nan
153        elif value == 'Infinity':
154            return inf
155        elif value == '-Infinity':
156            return -inf
157        raise
158
159
# Casting functions for values of the given PostgreSQL types.
# Values of types not listed here are returned without typecast.
_cast = dict(
    bool=_cast_bool, bytea=_cast_bytea,
    int2=int, int4=int, serial=int,
    int8=long, oid=long, oid8=long,
    float4=_cast_float, float8=_cast_float,
    numeric=Decimal, money=_cast_money)
165
166
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls with an empty sqlstate.

    The class defaults to DatabaseError. The exception is returned,
    not raised, so that callers can raise it themselves.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
172
173
def _op_error(msg):
    """Create (but do not raise) an OperationalError with empty sqlstate."""
    return _db_error(msg, cls=OperationalError)
177
178
class TypeCache(dict):
    """Cache for database types.

    The cache maps type oids to type descriptions. Entries that are
    not cached yet are looked up in the pg_type system catalog.
    """

    def __init__(self, cnx):
        """Initialize the type cache for the given connection."""
        super(TypeCache, self).__init__()
        # a separate source object is used for the catalog queries
        self._src = cnx.source()

    @staticmethod
    def typecast(typ, value):
        """Cast value to the database type with the name typ.

        NULL values (None) and values of types without a registered
        casting function are returned unchanged.
        """
        if value is None:
            return None
        cast = _cast.get(typ)
        return value if cast is None else cast(value)

    def getdescr(self, oid):
        """Get the description of the database type with the given oid.

        The result is the tuple (typname, None, typlen, None, None, None).
        The column name is omitted from this tuple and has to be
        prepended by the caller.
        """
        if oid in self:
            return self[oid]
        src = self._src
        src.execute(
            "SELECT typname, typlen "
            "FROM pg_type WHERE oid=%s" % oid)
        row = src.fetch(1)[0]
        descr = (row[0], None, int(row[1]), None, None, None)
        self[oid] = descr
        return descr
214
215
216class _quotedict(dict):
217    """Dictionary with auto quoting of its items.
218
219    The quote attribute must be set to the desired quote function.
220
221    """
222
223    def __getitem__(self, key):
224        return self.quote(super(_quotedict, self).__getitem__(key))
225
226
227### Cursor Object
228
class Cursor(object):
    """Cursor object."""

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self._type_cache = dbcnx._type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns
        self.description = None
        # unofficial attributes for convenience and performance
        self.colnames = self.coltypes = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a custom row factory is used, so never build one
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursors compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object."""
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Returns the SQL representation of the value. This is either a
        string, or the value itself for ints, finite floats and Decimals.
        Raises InterfaceError for values of unsupported types.
        """
        if isinstance(val, (datetime, date, time, timedelta)):
            # date and time values are passed as quoted strings
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            val = "'%s'" % val
        elif isinstance(val, (int, long)):
            pass
        elif isinstance(val, float):
            # special float values must be passed as quoted strings
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            elif isnan(val):
                return "'NaN'"
        elif val is None:
            val = 'NULL'
        elif isinstance(val, (list, tuple)):
            val = '(%s)' % ','.join(str(self._quote(v)) for v in val)
        elif Decimal is not float and isinstance(val, Decimal):
            pass
        elif hasattr(val, '__pg_repr__'):
            val = val.__pg_repr__()
        else:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))
        return val

    def _quoteparams(self, string, parameters):
        """Quote parameters and bind them to the string.

        This function works for both mappings and sequences; the
        quoted values are inserted using the % operator.
        """
        if isinstance(parameters, dict):
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def close(self):
        """Close the cursor object."""
        self._src.close()
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command)."""

        # The parameters may also be specified as list of
        # tuples to e.g. insert multiple rows in a single
        # operation, but this kind of usage is deprecated:
        if (parameters and isinstance(parameters, list) and
                isinstance(parameters[0], tuple)):
            return self.executemany(operation, parameters)
        else:
            # not a list of tuples
            return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        A transaction is started first if none is active on the
        connection. If the parameter sequence is empty, nothing is done
        and None is returned. Otherwise the cursor itself is returned.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                if parameters:
                    sql = self._quoteparams(operation, parameters)
                else:
                    sql = operation
                rows = self._src.execute(sql)
                if rows:
                    # sum up the row counts reported by the executions
                    rowcount += rows
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error("error in '%s': '%s' " % (sql, err))
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result row count and description
        if self._src.resulttype == RESULT_DQL:
            self.rowcount = self._src.ntuples
            getdescr = self._type_cache.getdescr
            description = [CursorDescription(
                info[1], *getdescr(info[2])) for info in self._src.listinfo()]
            self.colnames = [info[0] for info in description]
            self.coltypes = [info[1] for info in description]
            self.description = description
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set.

        Returns None if no more rows are available.
        """
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        typecast = self._type_cache.typecast
        return [self.row_factory([typecast(typ, value)
            for typ, value in zip(self.coltypes, row)]) for row in result]

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = parameters and len(parameters) or 0
        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
        self.execute(query, parameters)
        return parameters

    @staticmethod
    def _copy_options(format, sep, null):
        """Check the options shared by copy_from() and copy_to().

        Returns the list of option clauses for the copy command and the
        list of parameters to be bound to their placeholders.
        Raises TypeError or ValueError for invalid options.
        """
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("the format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("the sep option must be a string")
            if format == 'binary':
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("the null option must be a string")
            options.append('null %s')
            params.append(null)
        return options, params

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.
        In the non binary formats, rows passed as strings or iterables are
        terminated with a newline if necessary; binary data is passed as is.

        The size option sets the size of the buffer used when reading data
        from file-like objects. It must be an integer.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            if size:
                raise ValueError("size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("the input must be %s" % type_name)
                if not binary_format:
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    else:
                        if not stream.endswith(b'\n'):
                            stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "input stream must consist of %s" % type_name)
                        if not binary_format:
                            # text and csv rows must be newline-terminated,
                            # but binary data must not be altered
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            elif not chunk.endswith(b'\n'):
                                chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, int):
                # check the type before comparing the size with zero
                raise TypeError("the size option must be an integer")
            if size > 0:

                def chunks():
                    while True:
                        buffer = read(size)
                        yield buffer
                        if not buffer or len(buffer) < size:
                            break

            else:

                def chunks():
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("must specify a table, not a query")
        table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options, params = self._copy_options(format, sep, null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy from")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options, params = self._copy_options(format, sep, null)
        if decode is None:
            if binary_format:
                decode = False
            else:
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("the decode option must be a boolean")
            if decode and binary_format:
                raise ValueError("decode is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            self.rowcount = 0
            while True:
                row = getdata(decode)
                if isinstance(row, int):
                    # an int ends the data and is taken as the row count
                    if self.rowcount != row:
                        self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no output stream, return the generator
            return copy()

        # write the rows to the file-like output stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Note that since Python 2.6 the iterator protocol uses __next()__
    # instead of next(), we keep it only for backward compatibility of pgdb.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.
        Returns None if there are no column names.
        """
        colnames = self.colnames
        if colnames:
            try:
                try:
                    return namedtuple('Row', colnames, rename=True)._make
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    colnames = [v if v.isalnum() else 'column_%d' % n
                             for n, v in enumerate(colnames)]
                    return namedtuple('Row', colnames)._make
            except ValueError:  # there is still a problem with the field names
                colnames = ['column_%d' % n for n in range(len(colnames))]
                return namedtuple('Row', colnames)._make
742
743
# Description of a single result column, as found in the list
# stored in the description attribute of cursor objects.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
747
748
749### Connection Objects
750
class Connection(object):
    """Connection object."""

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        Raises OperationalError if the given low-level connection
        cannot provide a source object, i.e. if it is not usable.
        """
        self._cnx = cnx  # connection
        self._tnx = False  # transaction state
        self.cursor_type = Cursor
        try:
            # The type cache creates its own source object, so this also
            # checks the connection. Make sure that a failure here is
            # reported as OperationalError and not as a raw error.
            self._type_cache = TypeCache(cnx)
        except Exception:
            raise _op_error("invalid connection")

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is committed if the context is left without an exception,
        and rolled back otherwise.
        """
        if et is None and ev is None and tb is None:
            self.commit()
        else:
            self.rollback()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back first, ignoring any
        DatabaseError raised while doing so. Raises OperationalError
        if the connection has already been closed.
        """
        if self._cnx:
            if self._tnx:
                try:
                    self.rollback()
                except DatabaseError:
                    pass  # best effort, the connection is closed anyway
            self._cnx.close()
            self._cnx = None
        else:
            raise _op_error("connection has been closed")

    def commit(self):
        """Commit any pending transaction to the database.

        Raises OperationalError if the connection has been closed.
        """
        if self._cnx:
            if self._tnx:
                self._tnx = False
                try:
                    self._cnx.source().execute("COMMIT")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't commit")
        else:
            raise _op_error("connection has been closed")

    def rollback(self):
        """Roll back to the start of any pending transaction.

        Raises OperationalError if the connection has been closed.
        """
        if self._cnx:
            if self._tnx:
                self._tnx = False
                try:
                    self._cnx.source().execute("ROLLBACK")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't rollback")
        else:
            raise _op_error("connection has been closed")

    def cursor(self):
        """Return a new cursor object using the connection.

        The cursor is an instance of the class set as cursor_type.
        """
        if self._cnx:
            try:
                return self.cursor_type(self)
            except Exception:
                raise _op_error("invalid connection")
        else:
            raise _op_error("connection has been closed")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cursor = self.cursor()
            cursor.execute(operation, params)
            return cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cursor = self.cursor()
            cursor.executemany(operation, param_seq)
            return cursor
860
861
862### Module Interface
863
_connect_ = connect


def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database and return a Connection object.

    The dsn has the form 'host:database:user:password:opt', where all
    parts are optional. The keyword arguments user, password and database
    override the corresponding parts of the dsn. The host keyword
    argument overrides the host and may also carry a port as 'host:port'.
    """
    # take the parts from the dsn, as far as they are given
    dsn_parts = ["", "", "", "", ""]
    try:
        split_dsn = dsn.split(":")
    except (AttributeError, TypeError):
        split_dsn = []
    for index, part in enumerate(split_dsn[:5]):
        dsn_parts[index] = part
    dbhost, dbbase, dbuser, dbpasswd, dbopt = dsn_parts
    dbport = -1

    # the keyword arguments take precedence over the dsn
    if user is not None:
        dbuser = user
    if password is not None:
        dbpasswd = password
    if database is not None:
        dbbase = database
    if host is not None:
        try:
            host_parts = host.split(":")
        except (AttributeError, TypeError):
            host_parts = None
        if host_parts:
            dbhost = host_parts[0]
            if len(host_parts) > 1:
                try:
                    dbport = int(host_parts[1])
                except (TypeError, ValueError):
                    pass

    # an empty host or user is passed on as None
    if dbhost == "":
        dbhost = None
    if dbuser == "":
        dbuser = None

    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
911
912
913### Types Handling
914
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozen set of PostgreSQL type names.  Comparing it with a
    string tests membership, so ``'int4' == NUMBER`` is true.  Comparing
    it with anything else uses ordinary frozenset semantics.
    """

    def __new__(cls, values):
        """Create a type object from an iterable of type names.

        A single string is split on whitespace into separate names.
        """
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        """Membership test for strings, set equality otherwise."""
        if isinstance(other, basestring):
            return other in self
        return super(Type, self).__eq__(other)

    def __ne__(self, other):
        """Negation of __eq__: non-membership test for strings."""
        if isinstance(other, basestring):
            return other not in self
        return super(Type, self).__ne__(other)

    # Overriding __eq__ implicitly sets __hash__ to None in Python 3.  That
    # made type objects unhashable there, although they are hashable under
    # Python 2 and frozensets are hashable in general.  Restore the frozenset
    # hash so type objects can go into sets and serve as dict keys.
    # The hash depends only on the set of names.  It does not (and cannot)
    # agree with the string comparison in __eq__.
    __hash__ = frozenset.__hash__
939
940
# Type objects required by the DB-API 2 specification, each listing the
# PostgreSQL type names that fall into its category:

STRING = Type(('char', 'bpchar', 'name', 'text', 'varchar'))
BINARY = Type(('bytea',))
NUMBER = Type(('int2', 'int4', 'serial', 'int8',
    'float4', 'float8', 'numeric', 'money'))
DATETIME = Type(('date', 'time', 'timetz', 'timestamp', 'timestamptz',
    'datetime', 'abstime', 'interval', 'tinterval', 'timespan', 'reltime'))
ROWID = Type(('oid', 'oid8'))
949
950
# Finer-grained type objects beyond those the DB-API requires:

BOOL = Type(('bool',))
SMALLINT = Type(('int2',))
INTEGER = Type(('int2', 'int4', 'int8', 'serial'))
LONG = Type(('int8',))
FLOAT = Type(('float4', 'float8'))
NUMERIC = Type(('numeric',))
MONEY = Type(('money',))
DATE = Type(('date',))
TIME = Type(('time', 'timetz'))
TIMESTAMP = Type(('timestamp', 'timestamptz', 'datetime', 'abstime'))
INTERVAL = Type(('interval', 'tinterval', 'timespan', 'reltime'))
964
965
966# Mandatory type helpers defined by DB-API 2 specs:
967
def Date(year, month, day):
    """Build a date value from its year, month and day."""
    return date(year=year, month=month, day=day)
971
972
def Time(hour, minute=0, second=0, microsecond=0):
    """Build a time-of-day value; omitted components default to zero."""
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond)
976
977
def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Construct an object holding a time stamp value.

    The time components are optional and default to zero.
    """
    return datetime(year=year, month=month, day=day, hour=hour,
                    minute=minute, second=second, microsecond=microsecond)
981
982
def DateFromTicks(ticks):
    """Convert seconds since the epoch to a date in local time."""
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)
986
987
def TimeFromTicks(ticks):
    """Convert seconds since the epoch to a local time of day.

    Only whole seconds are kept; any fractional part of ticks is dropped.
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)
991
992
def TimestampFromTicks(ticks):
    """Convert seconds since the epoch to a local time stamp.

    Only whole seconds are kept; any fractional part of ticks is dropped.
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
996
997
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is a plain ``bytes`` subclass that adds no behavior of its own.
    """
1000
1001
# If run as a script, print some information about the module:

if __name__ == '__main__':
    # Show the version (``version`` is defined elsewhere in this module),
    # an empty separator line, and then the module docstring.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.