source: trunk/pgdb.py @ 774

Last change on this file since 774 was 774, checked in by cito, 4 years ago

Add support for JSON and JSONB to pg and pgdb

This adds all necessary functions to make PyGreSQL automatically
convert between JSON columns and Python objects representing them.

The documentation has also been updated, see there for the details.

Also, tuples automatically bind to ROW expressions in pgdb now.

  • Property svn:keywords set to Id
File size: 32.5 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 774 2016-01-21 18:49:28Z cito $
8#
9
"""pgdb - DB-API 2.0 compliant module for PyGreSQL.

(c) 1999, Pascal Andre <andre@via.ecp.fr>.
See package documentation for further information on copyright.

Inline documentation is sparse.
See the DB-API 2.0 specification for usage information:
https://peps.python.org/pep-0249/

Basic usage:

    pgdb.connect(connect_string) # open a connection
    # connect_string = 'host:database:user:password:opt'
    # All parts are optional. You may also pass host, database,
    # user and password as keyword arguments. To pass a port,
    # pass it in the host keyword parameter:
    connection = pgdb.connect(host='localhost:5432')

    cursor = connection.cursor() # open a cursor

    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is the same as the %
    # operator ("pyformat" paramstyle); each parameter is quoted and
    # escaped according to its Python type before it is substituted.
    # Returns the cursor, so calls can be chained.

    cursor.executemany(query, list of params)
    # Execute a query many times, binding each set of params
    # from the list.

    cursor.fetchone() # fetch one row, (value, value, ...)
    # By default, rows are returned as named tuples.

    cursor.fetchall() # fetch all rows, [(value, value, ...), ...]

    cursor.fetchmany([size])
    # returns size or cursor.arraysize number of rows,
    # [(value, value, ...), ...] from result set.
    # Default cursor.arraysize is 1.

    cursor.description # returns information about the columns
    #   [(column_name, type_name, display_size,
    #           internal_size, precision, scale, null_ok), ...]
    # Note that display_size, precision, scale and null_ok
    # are not implemented.

    cursor.rowcount # number of rows in the result set of a query,
    # or number of rows affected by a command.
    # Available after a call to execute.

    connection.commit() # commit transaction

    connection.rollback() # or rollback transaction

    cursor.close() # close the cursor

    connection.close() # close the connection
"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74from collections import namedtuple
75from json import loads as jsondecode, dumps as jsonencode
76
# Compatibility shims so the module runs on Python 2 and Python 3.

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# The abstract base classes live in collections.abc since Python 3.3;
# their old aliases in the collections module were removed in Python 3.10.
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3
    from collections import Iterable

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        from ordereddict import OrderedDict
    except Exception:
        def OrderedDict(*args):
            raise NotSupportedError('OrderedDict is not supported')
101
102
103### Module Constants
104
# DB-API 2.0 module globals (see PEP 249)

# the version of the DB API this module conforms to
apilevel = "2.0"

# threads may share the module, but must not share connections
threadsafety = 1

# query parameters use extended Python format codes like %(name)s
paramstyle = "pyformat"

# execute() and executemany() shortcuts on connections are not part of
# DB API 2 and not recommended by the DB SIG, but they can be handy
shortcutmethods = 1
117
118
119### Internal Types Handling
120
def decimal_type(decimal_type=None):
    """Return the global type used for decimal values, optionally setting it.

    When decimal_type is given, it is registered as the typecast for
    numeric columns and becomes the new global Decimal type.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    return Decimal
127
128
129def _cast_bool(value):
130    return value[:1] in ('t', 'T')
131
132
133def _cast_money(value):
134    return Decimal(''.join(filter(
135        lambda v: v in '0123456789.-', value)))
136
137
# Typecast functions keyed by PostgreSQL type name; values of types that
# are not listed here are passed through unchanged by TypeCache.typecast.
_cast = dict(
    # boolean and binary types
    bool=_cast_bool, bytea=unescape_bytea,
    # integer and object identifier types
    int2=int, int4=int, serial=int, int8=long,
    oid=long, oid8=long,
    # JSON types are decoded into Python objects
    json=jsondecode, jsonb=jsondecode,
    # floating point, exact numeric and monetary types
    float4=float, float8=float,
    numeric=Decimal, money=_cast_money)
144
145
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls (DatabaseError by default).

    The sqlstate attribute of the new exception is set to None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
151
152
def _op_error(msg):
    """Create an OperationalError with its sqlstate attribute set to None."""
    return _db_error(msg, cls=OperationalError)
156
157
class TypeCache(dict):
    """Cache for database types.

    Maps type oids to partial column descriptions read from pg_type.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(TypeCache, self).__init__()
        # separate source object used only for the pg_type lookups
        self._src = cnx.source()

    @staticmethod
    def typecast(typ, value):
        """Cast value to the Python type registered for database type typ.

        None (NULL) and values of types without a registered typecast
        function are returned unchanged.
        """
        if value is None:
            return None
        cast = _cast.get(typ)
        return value if cast is None else cast(value)

    def getdescr(self, oid):
        """Get the description of the database type with the given oid.

        Returns the 6-tuple (type name, None, type length, None, None, None).
        The column name is not included and must be prepended by the caller.
        pg_type is queried only once per oid; later calls use the cache.
        """
        if oid not in self:
            self._src.execute(
                "SELECT typname, typlen "
                "FROM pg_type WHERE oid=%s" % oid)
            row = self._src.fetch(1)[0]
            self[oid] = (row[0], None, int(row[1]), None, None, None)
        return self[oid]
193
194
195class _quotedict(dict):
196    """Dictionary with auto quoting of its items.
197
198    The quote attribute must be set to the desired quote function.
199    """
200
201    def __getitem__(self, key):
202        return self.quote(super(_quotedict, self).__getitem__(key))
203
204
205### Cursor Object
206
class Cursor(object):
    """Cursor object.

    Each cursor runs its queries through its own source object obtained
    from the low-level connection of the given Connection object.
    """

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection."""
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self._type_cache = dbcnx._type_cache
        self._src = self._cnx.source()
        # the official attribute for describing the result columns
        self.description = None
        # unofficial attributes for convenience and performance
        self.colnames = self.coltypes = None
        if self.row_factory is Cursor.row_factory:
            # the row factory needs to be determined dynamically
            self.row_factory = None
        else:
            # a subclass provides a static row factory,
            # so none must be built dynamically after each query
            self.build_row_factory = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursor compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object by closing it."""
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Strings are escaped and put in single quotes, Binary values are
        escaped as bytea, lists become ARRAY and tuples ROW constructors
        with their items quoted recursively, None becomes NULL, and objects
        with a __pg_repr__() method are converted with that method.
        Numbers are returned unchanged. Raises InterfaceError for any
        other type.
        """
        if isinstance(val, (datetime, date, time, timedelta, Json)):
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            val = "'%s'" % val
        elif isinstance(val, (int, long)):
            pass
        elif isinstance(val, float):
            # special float values must be passed as quoted literals
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            elif isnan(val):
                return "'NaN'"
        elif val is None:
            val = 'NULL'
        elif isinstance(val, list):
            q = self._quote
            val = 'ARRAY[%s]' % ','.join(str(q(v)) for v in val)
        elif isinstance(val, tuple):
            q = self._quote
            val = 'ROW(%s)' % ','.join(str(q(v)) for v in val)
        elif Decimal is not float and isinstance(val, Decimal):
            # Decimal may have been replaced by float via decimal_type()
            pass
        elif hasattr(val, '__pg_repr__'):
            val = val.__pg_repr__()
        else:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))
        return val

    def _quoteparams(self, string, parameters):
        """Quote parameters and substitute them into the string.

        This function works for both mappings and sequences.
        """
        if isinstance(parameters, dict):
            # values are quoted lazily when the % operator looks them up
            parameters = _quotedict(parameters)
            parameters.quote = self._quote
        else:
            parameters = tuple(map(self._quote, parameters))
        return string % parameters

    def close(self):
        """Close the cursor object."""
        self._src.close()
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, parameters=None):
        """Prepare and execute a database operation (query or command)."""
        # The parameters may also be specified as list of tuples to e.g.
        # insert multiple rows in a single operation, but this kind of
        # usage is deprecated.  We make several plausibility checks because
        # tuples can also be passed with the meaning of ROW constructors.
        if (parameters and isinstance(parameters, list)
                and len(parameters) > 1
                and all(isinstance(p, tuple) for p in parameters)
                and all(len(p) == len(parameters[0]) for p in parameters[1:])):
            return self.executemany(operation, parameters)
        else:
            # not a list of tuples
            return self.executemany(operation, [parameters])

    def executemany(self, operation, seq_of_parameters):
        """Prepare operation and execute it against a parameter sequence.

        A transaction is started first if none is active on the connection.
        Returns the cursor itself, or None if seq_of_parameters is empty.
        """
        if not seq_of_parameters:
            # don't do anything without parameters
            return
        self.description = None
        self.colnames = self.coltypes = None
        self.rowcount = -1
        # first try to execute all queries
        rowcount = 0
        # sql always holds the statement currently run, for error messages
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise  # database provides error message
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for parameters in seq_of_parameters:
                sql = operation
                if parameters:
                    sql = self._quoteparams(sql, parameters)
                rows = self._src.execute(sql)
                # NOTE(review): assumes execute() returns the number of
                # affected rows for commands and a falsy value otherwise
                if rows:
                    rowcount += rows
                else:
                    self.rowcount = -1
        except DatabaseError:
            raise  # database provides error message
        except Error as err:
            raise _db_error(
                "error in '%s': '%s' " % (sql, err), InterfaceError)
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result raw count and description
        if self._src.resulttype == RESULT_DQL:
            self.rowcount = self._src.ntuples
            getdescr = self._type_cache.getdescr
            description = [CursorDescription(
                info[1], *getdescr(info[2])) for info in self._src.listinfo()]
            self.colnames = [info[0] for info in description]
            self.coltypes = [info[1] for info in description]
            self.description = description
            self.lastrowid = None
            if self.build_row_factory:
                self.row_factory = self.build_row_factory()
        else:
            self.rowcount = rowcount
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set (None if exhausted)."""
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.

        Each value is typecast according to its column type and every
        row is passed through the row factory before it is returned.
        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        typecast = self._type_cache.typecast
        return [self.row_factory([typecast(typ, value)
            for typ, value in zip(self.coltypes, row)]) for row in result]

    def callproc(self, procname, parameters=None):
        """Call a stored database procedure with the given name.

        The sequence of parameters must contain one entry for each input
        argument that the procedure expects. The result of the call is the
        same as this input sequence; replacement of output and input/output
        parameters in the return value is currently not supported.

        The procedure may also provide a result set as output. These can be
        requested through the standard fetch methods of the cursor.
        """
        n = parameters and len(parameters) or 0
        # NOTE(review): procname is put in double quotes without escaping,
        # so it must not come from untrusted input
        query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s']))
        self.execute(query, parameters)
        return parameters

    def copy_from(self, stream, table,
            format=None, sep=None, null=None, size=None, columns=None):
        """Copy data from an input stream to the specified table.

        The input stream can be a file-like object with a read() method or
        it can also be an iterable returning a row or multiple rows of input
        on each iteration.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the input.

        The size option sets the size of the buffer used when reading data
        from file-like objects.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.

        In the text and csv formats, a newline is appended to string input
        that does not end with one; binary input is passed unchanged.
        """
        binary_format = format == 'binary'
        try:
            read = stream.read
        except AttributeError:
            if size:
                raise ValueError("size must only be set for file-like objects")
            if binary_format:
                input_type = bytes
                type_name = 'byte strings'
            else:
                input_type = basestring
                type_name = 'strings'

            if isinstance(stream, basestring):
                if not isinstance(stream, input_type):
                    raise ValueError("The input must be %s" % type_name)
                if not binary_format:
                    if isinstance(stream, str):
                        if not stream.endswith('\n'):
                            stream += '\n'
                    else:
                        if not stream.endswith(b'\n'):
                            stream += b'\n'

                def chunks():
                    yield stream

            elif isinstance(stream, Iterable):

                def chunks():
                    for chunk in stream:
                        if not isinstance(chunk, input_type):
                            raise ValueError(
                                "Input stream must consist of %s" % type_name)
                        # rows in text and csv format must end with a
                        # newline, but binary data must not be altered
                        if not binary_format:
                            if isinstance(chunk, str):
                                if not chunk.endswith('\n'):
                                    chunk += '\n'
                            else:
                                if not chunk.endswith(b'\n'):
                                    chunk += b'\n'
                        yield chunk

            else:
                raise TypeError("Need an input stream to copy from")
        else:
            if size is None:
                size = 8192
            elif not isinstance(size, int):
                raise TypeError("The size option must be an integer")
            if size > 0:

                def chunks():
                    while True:
                        buffer = read(size)
                        yield buffer
                        if not buffer or len(buffer) < size:
                            break

            else:

                def chunks():
                    yield read()

        if not table or not isinstance(table, basestring):
            raise TypeError("Need a table to copy to")
        if table.lower().startswith('select'):
            raise ValueError("Must specify a table, not a query")
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if format == 'binary':
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))
        operation.append("from stdin")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        putdata = self._src.putdata
        self.execute(operation, params)

        try:
            for chunk in chunks():
                putdata(chunk)
        except BaseException as error:
            self.rowcount = -1
            # the following call will re-raise the error
            putdata(error)
        else:
            self.rowcount = putdata(None)

        # return the cursor object, so you can chain operations
        return self

    def copy_to(self, stream, table,
            format=None, sep=None, null=None, decode=None, columns=None):
        """Copy data from the specified table to an output stream.

        The output stream can be a file-like object with a write() method or
        it can also be None, in which case the method will return a generator
        yielding a row on each iteration.

        Output will be returned as byte strings unless you set decode to true.

        Note that you can also use a select query instead of the table name.

        The format must be text, csv or binary. The sep option sets the
        column separator (delimiter) used in the non binary formats.
        The null option sets the textual representation of NULL in the output.

        The copy operation can be restricted to a subset of columns. If no
        columns are specified, all of them will be copied.
        """
        binary_format = format == 'binary'
        if stream is not None:
            try:
                write = stream.write
            except AttributeError:
                raise TypeError("need an output stream to copy to")
        if not table or not isinstance(table, basestring):
            raise TypeError("need a table to copy from")
        if table.lower().startswith('select'):
            if columns:
                raise ValueError("columns must be specified in the query")
            table = '(%s)' % (table,)
        else:
            table = '"%s"' % (table,)
        operation = ['copy %s' % (table,)]
        options = []
        params = []
        if format is not None:
            if not isinstance(format, basestring):
                raise TypeError("format option must be a string")
            if format not in ('text', 'csv', 'binary'):
                raise ValueError("invalid format")
            options.append('format %s' % (format,))
        if sep is not None:
            if not isinstance(sep, basestring):
                raise TypeError("sep option must be a string")
            if binary_format:
                raise ValueError("sep is not allowed with binary format")
            if len(sep) != 1:
                raise ValueError("sep must be a single one-byte character")
            options.append('delimiter %s')
            params.append(sep)
        if null is not None:
            if not isinstance(null, basestring):
                raise TypeError("null option must be a string")
            options.append('null %s')
            params.append(null)
        if decode is None:
            if format == 'binary':
                decode = False
            else:
                # decode by default only where str is a Unicode type
                decode = str is unicode
        else:
            if not isinstance(decode, (int, bool)):
                raise TypeError("decode option must be a boolean")
            if decode and binary_format:
                raise ValueError("decode is not allowed with binary format")
        if columns:
            if not isinstance(columns, basestring):
                columns = ','.join('"%s"' % (col,) for col in columns)
            operation.append('(%s)' % (columns,))

        operation.append("to stdout")
        if options:
            operation.append('(%s)' % ','.join(options))
        operation = ' '.join(operation)

        getdata = self._src.getdata
        self.execute(operation, params)

        def copy():
            # getdata() yields rows until it returns an int
            # which is then taken as the final row count
            self.rowcount = 0
            while True:
                row = getdata(decode)
                if isinstance(row, int):
                    if self.rowcount != row:
                        self.rowcount = row
                    break
                self.rowcount += 1
                yield row

        if stream is None:
            # no output stream, return the generator
            return copy()

        # write the rows to the file-like output stream
        for row in copy():
            write(row)

        # return the cursor object, so you can chain operations
        return self

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Python 3 uses __next__() for the iterator protocol, Python 2 uses
    # next(); the alias also keeps backward compatibility of pgdb.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this statically with a custom row factory, or
        you can build a row factory dynamically with build_row_factory().

        For example, you can create a Cursor class that returns rows as
        Python dictionaries like this:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]: value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)  # get one DictCursor instance or
            con.cursor_type = DictCursor  # always use DictCursor instances
        """
        raise NotImplementedError

    def build_row_factory(self):
        """Build a row factory based on the current description.

        This implementation builds a row factory for creating named tuples.
        You can overwrite this method if you want to dynamically create
        different row factories whenever the column description changes.
        Returns None if there are no column names.
        """
        colnames = self.colnames
        if colnames:
            try:
                try:
                    return namedtuple('Row', colnames, rename=True)._make
                except TypeError:  # Python 2.6 and 3.0 do not support rename
                    colnames = [v if v.isalnum() else 'column_%d' % n
                             for n, v in enumerate(colnames)]
                    return namedtuple('Row', colnames)._make
            except ValueError:  # there is still a problem with the field names
                colnames = ['column_%d' % n for n in range(len(colnames))]
                return namedtuple('Row', colnames)._make
719
720
# Named tuple describing one result column in Cursor.description,
# holding the seven items required by DB-API 2.0.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
724
725
726### Connection Objects
727
class Connection(object):
    """Connection object.

    Wraps a low-level connection and keeps track of whether a transaction
    has been started by one of its cursors.
    """

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        Raises OperationalError if no source object can be obtained
        from the low-level connection.
        """
        self._cnx = cnx  # connection
        self._tnx = False  # transaction state
        self.cursor_type = Cursor
        # Check the connection before creating the type cache, since the
        # type cache needs a source object of the connection as well and
        # would otherwise fail with a raw, unwrapped exception.
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("invalid connection")
        self._type_cache = TypeCache(cnx)

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.
        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is committed on success and rolled back after an exception.
        """
        if et is None and ev is None and tb is None:
            self.commit()
        else:
            self.rollback()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back first (database errors during
        the rollback are ignored). Raises OperationalError if the connection
        has already been closed.
        """
        if self._cnx:
            if self._tnx:
                try:
                    self.rollback()
                except DatabaseError:
                    pass
            self._cnx.close()
            self._cnx = None
        else:
            raise _op_error("connection has been closed")

    def commit(self):
        """Commit any pending transaction to the database."""
        if self._cnx:
            if self._tnx:
                # the transaction is considered ended even if COMMIT fails
                self._tnx = False
                try:
                    self._cnx.source().execute("COMMIT")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't commit")
        else:
            raise _op_error("connection has been closed")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        if self._cnx:
            if self._tnx:
                # the transaction is considered ended even if ROLLBACK fails
                self._tnx = False
                try:
                    self._cnx.source().execute("ROLLBACK")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't rollback")
        else:
            raise _op_error("connection has been closed")

    def cursor(self):
        """Return a new cursor object (of type cursor_type) for the connection."""
        if self._cnx:
            try:
                return self.cursor_type(self)
            except Exception:
                raise _op_error("invalid connection")
        else:
            raise _op_error("connection has been closed")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cursor = self.cursor()
            cursor.execute(operation, params)
            return cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cursor = self.cursor()
            cursor.executemany(operation, param_seq)
            return cursor
835
836
837### Module Interface
838
# Keep a reference to the low-level connect() function (brought in by the
# _pg star import above) before the DB-API connect() below shadows its name.
_connect_ = connect
840
def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connect to a database.

    The optional dsn is a string 'host:database:user:password:opt'.  The
    parts are separated by colons.  Trailing parts may be omitted, in which
    case they count as empty, and parts beyond the fifth are ignored.  A dsn
    that cannot be split (e.g. None) is ignored altogether.

    The keyword arguments user, password and database replace the
    corresponding DSN parts when they are not None.  The host keyword
    argument may be given as 'host' or 'host:port'.  It replaces the DSN
    host; a port that is missing or not an integer leaves the port at -1.

    An empty host name or user name is passed on as None.  Returns a
    Connection wrapping the low-level connection object.
    """
    # Parse the DSN; padding and slicing yields exactly five parts.
    dbhost = dbbase = dbuser = dbpasswd = dbopt = ""
    try:
        dbhost, dbbase, dbuser, dbpasswd, dbopt = (
            list(dsn.split(":")) + [""] * 5)[:5]
    except (AttributeError, IndexError, TypeError):
        pass

    # Explicitly passed keyword arguments take precedence over the DSN.
    dbuser = dbuser if user is None else user
    dbpasswd = dbpasswd if password is None else password
    dbbase = dbbase if database is None else database

    dbport = -1
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # keep whatever could be parsed so far

    # empty host is localhost; empty names are passed on as None
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    # open the connection
    low_level = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(low_level)
886
887
888### Types Handling
889
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozenset of type names.  It compares equal to any single
    type name (a string) that it contains, and it compares with other sets
    using ordinary set equality.
    """

    def __new__(cls, values):
        # Accept a whitespace-separated string of names as a convenience.
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            return other in self
        return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            return other not in self
        return super(Type, self).__ne__(other)

    # Defining __eq__ without __hash__ makes Python 3 set __hash__ to None,
    # which would make Type objects unhashable.  That differs from Python 2
    # and from the frozenset base class.  Keep the frozenset hash so Type
    # objects can again be used in sets and as dict keys.  This hash is
    # consistent with set equality only: a Type that compares equal to a
    # type name string does not hash like that string.
    __hash__ = frozenset.__hash__
913
914
# Mandatory type objects required by the DB-API 2 specification.
# Each one compares equal to any of the PostgreSQL type names it lists.

STRING = Type(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = Type(['bytea'])
NUMBER = Type(['int2', 'int4', 'serial', 'int8',
               'float4', 'float8', 'numeric', 'money'])
DATETIME = Type(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
                 'datetime', 'abstime',
                 'interval', 'tinterval', 'timespan', 'reltime'])
ROWID = Type(['oid', 'oid8'])
923
924
# Additional, more specific type objects provided by PyGreSQL:

BOOL = Type(['bool'])
SMALLINT = Type(['int2'])
INTEGER = Type(['int2', 'int4', 'int8', 'serial'])
LONG = Type(['int8'])
FLOAT = Type(['float4', 'float8'])
NUMERIC = Type(['numeric'])
MONEY = Type(['money'])
DATE = Type(['date'])
TIME = Type(['time', 'timetz'])
TIMESTAMP = Type(['timestamp', 'timestamptz', 'datetime', 'abstime'])
INTERVAL = Type(['interval', 'tinterval', 'timespan', 'reltime'])
JSON = Type(['json', 'jsonb'])
939
940
941# Mandatory type helpers defined by DB-API 2 specs:
942
def Date(year, month, day):
    """Construct an object holding a date value.

    The value is built with the module's ``date`` type.
    """
    return date(year=year, month=month, day=day)


def Time(hour, minute=0, second=0, microsecond=0):
    """Construct an object holding a time value.

    Only the hour is required; the other components default to zero.
    """
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond)


def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Construct an object holding a time stamp value.

    Only the date part is required; the time components default to zero.
    """
    return datetime(year=year, month=month, day=day,
                    hour=hour, minute=minute, second=second,
                    microsecond=microsecond)


def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    The ticks are converted with ``localtime``, i.e. in local time.
    """
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)


def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    The ticks are converted with ``localtime``; fractions of a second are
    dropped, so the microsecond part is always zero.
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)


def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    The ticks are converted with ``localtime``; fractions of a second are
    dropped, so the microsecond part is always zero.
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
971
972
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is the DB-API 2 ``Binary`` constructor.  It is a plain subclass of
    ``bytes`` and adds no behavior of its own.
    NOTE(review): presumably the parameter-quoting code elsewhere in this
    module recognizes instances of this class as binary data -- verify.
    """
975
976
977# Additional type helpers for PyGreSQL:
978
class Json:
    """Construct a wrapper for holding an object serializable to JSON.

    ``obj`` is the wrapped value.  ``encode`` is the serializer to use; when
    it is not given (or is falsy), the module's ``jsonencode`` is used.
    Converting the wrapper with ``str()`` yields the JSON text.  A wrapped
    string is returned unchanged, i.e. it is taken to be JSON text already.
    """

    def __init__(self, obj, encode=None):
        self.obj = obj
        if not encode:
            encode = jsonencode
        self.encode = encode

    def __str__(self):
        value = self.obj
        if not isinstance(value, basestring):
            value = self.encode(value)
        return value

    # Same conversion under the hook name __pg_repr__.
    # NOTE(review): presumably consulted by the parameter-quoting code
    # elsewhere in this module -- verify.
    __pg_repr__ = __str__
993
994
995# If run as script, print some information:
996
if __name__ == '__main__':
    # When executed as a script, show the version followed by the module
    # docstring (the usage overview).
    # NOTE(review): ``version`` is presumably defined earlier in this module.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.