source: trunk/module/pgdb.py @ 678

Last change on this file since 678 was 678, checked in by cito, 4 years ago

More elegant code using decorators and dict comprehension

  • Property svn:keywords set to Id
File size: 20.0 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 678 2015-12-30 21:01:53Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt:tty'
    # All parts are optional. You may also pass the host, database,
    # user and password as keyword arguments. To pass a port,
    # append it to the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is the same as for the
    # % operator; parameter values are quoted and escaped automatically.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
    # Note that display_size, precision, scale and null_ok are not implemented.
52
    cursor.rowcount # number of rows in the result set, or for other
    # statements the number of affected rows. Available after execute.
55
56    connection.commit() # commit transaction
57
58    connection.rollback() # or rollback transaction
59
60    cursor.close() # close the cursor
61
62    connection.close() # close the connection
63
64"""
65
66from __future__ import print_function
67
68from _pg import *
69
70from datetime import date, time, datetime, timedelta
71from time import localtime
72from decimal import Decimal
73from math import isnan, isinf
74
# Python 2/3 compatibility aliases: the Python 2 builtins probed below
# no longer exist on Python 3, so equivalent Python 3 types are bound
# under the old names instead. Each probe is independent.

try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    long
except NameError:  # Python >= 3.0
    long = int
89
# Register the module's decimal type with _pg; presumably _pg uses it
# for its own numeric conversions (TODO confirm). decimal_type() can
# change it later.
set_decimal(Decimal)
91
92
93### Module Constants
94
# DB-API 2.0 module globals.

# this module is compliant with DB API 2.0
apilevel = '2.0'

# threads may share the module, but not connections
threadsafety = 1

# this module uses extended Python format codes, e.g. %(name)s
paramstyle = 'pyformat'

# Shortcut methods (execute and executemany on the connection) are
# not supported by default, since they have been excluded from
# DB API 2 and are not recommended by the DB SIG.
shortcutmethods = 0
109
110
111### Internal Types Handling
112
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal (numeric) values.

    When decimal_type is given, it becomes the cast function for the
    'numeric' database type, replaces the module-level Decimal name and
    is passed to set_decimal(). The active type is returned either way.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    set_decimal(Decimal)
    return Decimal
120
121
122def _cast_bool(value):
123    return value[:1] in ('t', 'T')
124
125
126def _cast_money(value):
127    return Decimal(''.join(filter(
128        lambda v: v in '0123456789.-', value)))
129
130
def _cast_bytea(value):
    """Undo the bytea escaping of a value returned by the server.

    The work is delegated to unescape_bytea() from _pg.
    """
    unescaped = unescape_bytea(value)
    return unescaped
133
134
135def _cast_float(value):
136    try:
137        return float(value)
138    except ValueError:
139        if value == 'NaN':
140            return nan
141        elif value == 'Infinity':
142            return inf
143        elif value == '-Infinity':
144            return -inf
145        raise
146
147
# Typecast functions keyed by PostgreSQL type name. Values of types
# without an entry are passed through unchanged by
# pgdbTypeCache.typecast(); decimal_type() may replace 'numeric'.
_cast = {
    'bool': _cast_bool,
    'bytea': _cast_bytea,
    'int2': int,
    'int4': int,
    'serial': int,
    'int8': long,
    'oid': long,
    'oid8': long,
    'float4': _cast_float,
    'float8': _cast_float,
    'numeric': Decimal,
    'money': _cast_money,
}
153
154
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an exception of class cls.

    The exception gets msg as its message and an sqlstate attribute
    set to None. cls defaults to DatabaseError.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
160
161
def _op_error(msg):
    """Create (but do not raise) an OperationalError with sqlstate None."""
    return _db_error(msg, cls=OperationalError)
165
166
class pgdbTypeCache(dict):
    """Cache for database types.

    Maps type OIDs to partial column descriptions read from the
    pg_type system catalog, so each OID is queried only once.
    """

    def __init__(self, cnx):
        """Initialize type cache for connection."""
        super(pgdbTypeCache, self).__init__()
        # dedicated query object used only for catalog lookups
        self._src = cnx.source()

    @staticmethod
    def typecast(typ, value):
        """Cast value to database type.

        NULL values (None) and values whose type name has no entry in
        _cast are returned unchanged.
        """
        if value is not None:
            cast = _cast.get(typ)
            if cast is not None:
                return cast(value)
        return value

    def getdescr(self, oid):
        """Get the description of the database type with given oid.

        Returns the 6-tuple (type_name, None, internal_size, None,
        None, None). The column name is not included; the caller must
        prepend it to obtain a full DB-API description entry.
        """
        if oid in self:
            return self[oid]
        self._src.execute(
            "SELECT typname, typlen "
            "FROM pg_type WHERE oid=%s" % oid)
        row = self._src.fetch(1)[0]
        descr = (row[0], None, int(row[1]), None, None, None)
        self[oid] = descr
        return descr
203
204
205class _quotedict(dict):
206    """Dictionary with auto quoting of its items.
207
208    The quote attribute must be set to the desired quote function.
209
210    """
211
212    def __getitem__(self, key):
213        return self.quote(super(_quotedict, self).__getitem__(key))
214
215
216### Cursor Object
217
class pgdbCursor(object):
    """Cursor object.

    Created by pgdbCnx.cursor(). A cursor shares the connection's
    transaction flag and type cache, but owns a separate query object
    (self._src) obtained from the low-level connection.
    """

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection.

        dbcnx is the owning pgdbCnx; it is also exposed as the DB-API
        ``connection`` attribute.
        """
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx  # low-level connection from _pg
        self._type_cache = dbcnx._type_cache  # shared with the connection
        self._src = self._cnx.source()  # query object used by this cursor
        self.description = None
        self.rowcount = -1  # reset to -1 whenever no result is available
        self.arraysize = 1  # default number of rows for fetchmany()
        self.lastrowid = None

    def __iter__(self):
        """Make cursors compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object.

        Closes the cursor; exceptions raised in the block are not
        suppressed.
        """
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Returns a quoted and escaped SQL literal (str) for strings,
        Binary values and date/time values, 'NULL' for None, and quoted
        special values for infinite and NaN floats. Ints, longs, other
        floats and Decimals are returned unchanged, lists and tuples
        become a parenthesized, comma-separated list of their quoted
        items, and objects with a __pg_repr__() method are rendered by it.

        Raises InterfaceError for values of any other type.
        """
        if isinstance(val, (datetime, date, time, timedelta)):
            # date/time values are passed in their string form and are
            # then quoted like any other string below
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            val = "'%s'" % val
        elif isinstance(val, (int, long)):
            pass
        elif isinstance(val, float):
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            elif isnan(val):
                return "'NaN'"
            # finite floats fall through and are returned unchanged.
            # NOTE(review): on Python 2, formatting them with %s uses
            # str(), which keeps only 12 significant digits - confirm
            # whether repr() precision is needed here.
        elif val is None:
            val = 'NULL'
        elif isinstance(val, (list, tuple)):
            val = '(%s)' % ','.join(map(lambda v: str(self._quote(v)), val))
        elif Decimal is not float and isinstance(val, Decimal):
            # the guard matters when decimal_type(float) has made Decimal
            # an alias of float (such values are handled above already)
            pass
        elif hasattr(val, '__pg_repr__'):
            val = val.__pg_repr__()
        else:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))
        return val

    def _quoteparams(self, string, params):
        """Quote parameters.

        This function works for both mappings and sequences.

        Only dict instances are treated as mappings; their values are
        quoted lazily when the % operator looks them up by name. Any
        other params are quoted eagerly into a tuple for positional
        placeholders. Returns the resulting SQL string.

        """
        if isinstance(params, dict):
            params = _quotedict(params)
            params.quote = self._quote
        else:
            params = tuple(map(self._quote, params))
        return string % params

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this with a custom row factory,
        e.g. a dict factory:

            class DictCursor(pgdb.pgdbCursor):

                def row_factory(self, row):
                    return {desc[0]:value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)

        """
        return row

    def close(self):
        """Close the cursor object.

        Closes the underlying query object and resets the result
        attributes.
        """
        self._src.close()
        self.description = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, params=None):
        """Prepare and execute a database operation (query or command).

        Delegates to executemany() with a single parameter set (or with
        params itself for a deprecated list of tuples, see below) and
        returns its result, i.e. the cursor.
        """

        # The parameters may also be specified as list of
        # tuples to e.g. insert multiple rows in a single
        # operation, but this kind of usage is deprecated:
        if (params and isinstance(params, list)
                and isinstance(params[0], tuple)):
            return self.executemany(operation, params)
        else:
            # not a list of tuples
            return self.executemany(operation, [params])

    def executemany(self, operation, param_seq):
        """Prepare operation and execute it against a parameter sequence.

        Opens a transaction with BEGIN if none is active on the
        connection, then runs operation once per parameter set.
        DatabaseError exceptions propagate unchanged; other Error
        exceptions are wrapped in a DatabaseError and any other
        exception in an OperationalError, both naming the failing SQL.
        Returns the cursor, or None if param_seq is empty.
        """
        if not param_seq:
            # don't do anything without parameters
            return
        self.description = None
        self.rowcount = -1
        # first try to execute all queries
        totrows = 0
        sql = "BEGIN"  # also reported in the error message if BEGIN fails
        try:
            if not self._dbcnx._tnx:
                # implicitly open a transaction; it stays open until the
                # connection's commit() or rollback() clears the flag
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for params in param_seq:
                if params:
                    sql = self._quoteparams(operation, params)
                else:
                    # without params no % formatting takes place, so
                    # the operation is sent exactly as given
                    sql = operation
                rows = self._src.execute(sql)
                # NOTE(review): this test was commented "true if not DML",
                # but totrows is only used as rowcount for non-DQL results
                # below, so rows is presumably the affected-row count of
                # DML statements - confirm against _pg.
                if rows:
                    totrows += rows
                else:
                    self.rowcount = -1
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error("error in '%s': '%s' " % (sql, err))
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result row count and description
        # (only the result of the last executed statement is kept)
        if self._src.resulttype == RESULT_DQL:
            self.rowcount = self._src.ntuples
            getdescr = self._type_cache.getdescr
            coltypes = self._src.listinfo()
            # listinfo() entries presumably hold the column name at [1]
            # and the type oid at [2]; the name is prepended to the
            # cached type description
            self.description = [
                typ[1:2] + getdescr(typ[2]) for typ in coltypes]
            self.lastrowid = None
        else:
            self.rowcount = totrows
            self.description = None
            # presumably the OID of an inserted row, if any - confirm
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set.

        Returns None when no more rows are available.
        """
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.

        Each value is cast according to its column type and each row
        is passed through row_factory(). DatabaseError exceptions
        propagate unchanged; other Error exceptions are wrapped in a
        DatabaseError.

        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        row_factory = self.row_factory
        typecast = self._type_cache.typecast
        # the type name is the second item of each description entry.
        # NOTE(review): if self.description is None here (no result
        # set), this raises TypeError unless fetch() failed first.
        coltypes = [desc[1] for desc in self.description]
        return [row_factory([typecast(*args)
            for args in zip(coltypes, row)]) for row in result]

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Python 3 calls __next__() in the iterator protocol, while Python 2
    # calls next(); the alias keeps both working and preserves the old
    # pgdb API.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed
440
441
442### Connection Objects
443
class pgdbCnx(object):
    """Connection object.

    Wraps a low-level _pg connection. Transactions are opened lazily
    by the cursors (see the _tnx flag) and ended by commit(),
    rollback() or leaving a ``with`` block.
    """

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        Raises OperationalError("invalid connection") if cnx cannot
        create a query object.
        """
        self._cnx = cnx  # connection
        self._tnx = False  # transaction state
        # Validate the connection before building the type cache: the
        # cache constructor calls cnx.source() itself, so creating it
        # first would let the raw exception escape instead of the
        # OperationalError raised here.
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("invalid connection")
        self._type_cache = pgdbTypeCache(cnx)

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.

        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it commits if the block completed normally and rolls back if it
        raised. Exceptions from the block are not suppressed.

        """
        if et is None and ev is None and tb is None:
            self.commit()
        else:
            self.rollback()

    def close(self):
        """Close the connection object.

        An open transaction is rolled back first; a DatabaseError from
        that rollback is ignored so the connection is still closed.
        Raises OperationalError if the connection is already closed.
        """
        if self._cnx:
            if self._tnx:
                try:
                    self.rollback()
                except DatabaseError:
                    pass
            self._cnx.close()
            self._cnx = None
        else:
            raise _op_error("connection has been closed")

    def commit(self):
        """Commit any pending transaction to the database.

        Does nothing if no transaction is open. Raises OperationalError
        if the connection is closed or COMMIT fails with a non-database
        error.
        """
        if self._cnx:
            if self._tnx:
                # clear the flag first so the next statement opens a new
                # transaction even if COMMIT fails
                self._tnx = False
                try:
                    self._cnx.source().execute("COMMIT")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't commit")
        else:
            raise _op_error("connection has been closed")

    def rollback(self):
        """Roll back to the start of any pending transaction.

        Does nothing if no transaction is open. Raises OperationalError
        if the connection is closed or ROLLBACK fails with a
        non-database error.
        """
        if self._cnx:
            if self._tnx:
                # clear the flag first, as in commit()
                self._tnx = False
                try:
                    self._cnx.source().execute("ROLLBACK")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't rollback")
        else:
            raise _op_error("connection has been closed")

    def cursor(self):
        """Return a new cursor object using the connection.

        Raises OperationalError if the connection is closed or the
        cursor cannot be created.
        """
        if self._cnx:
            try:
                return pgdbCursor(self)
            except Exception:
                raise _op_error("invalid connection")
        else:
            raise _op_error("connection has been closed")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cursor = self.cursor()
            cursor.execute(operation, params)
            return cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cursor = self.cursor()
            cursor.executemany(operation, param_seq)
            return cursor
552
553
554### Module Interface
555
_connect_ = connect  # keep the connect() imported from _pg before shadowing it


def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connects to a database.

    dsn has the form 'host:database:user:password:opt:tty'; every part
    is optional and missing trailing parts default to ''. The user,
    password and database keywords override the corresponding dsn
    parts. The host keyword overrides the dsn host and may carry a port
    as 'host:port'; a port that is not an integer is silently ignored.
    An empty host or user is passed on as None.

    Returns a pgdbCnx wrapping the new low-level connection.
    """
    # split the dsn into its six positional parts, padding with ''
    try:
        dsn_parts = dsn.split(":")
    except (AttributeError, TypeError):
        dsn_parts = []
    dbhost, dbbase, dbuser, dbpasswd, dbopt, dbtty = (
        list(dsn_parts) + [""] * 6)[:6]
    dbport = -1

    # explicit keyword arguments take precedence over the dsn
    dbuser = dbuser if user is None else user
    dbpasswd = dbpasswd if password is None else password
    dbbase = dbbase if database is None else database
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # keep whatever could be parsed so far

    # empty host/user are passed as None (empty host means localhost)
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    low_level = _connect_(dbbase, dbhost, dbport, dbopt,
        dbtty, dbuser, dbpasswd)
    return pgdbCnx(low_level)
606
607
608### Types Handling
609
class pgdbType(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    An instance compares equal to every type name (string) it contains,
    so a type name taken from cursor.description can be compared with
    e.g. STRING. Comparisons with other objects use frozenset semantics.

    """

    def __new__(cls, values):
        """Create a type object from an iterable of type names or from
        a single string of space-separated type names."""
        if isinstance(values, basestring):
            values = values.split()
        return super(pgdbType, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            return other in self
        else:
            return super(pgdbType, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            return other not in self
        else:
            return super(pgdbType, self).__ne__(other)

    # Python 3 implicitly sets __hash__ to None on classes that override
    # __eq__, which would make type objects unhashable (unlike on
    # Python 2). Keep the frozenset hash so they work in sets and dicts.
    __hash__ = frozenset.__hash__
634
635
# Mandatory type objects defined by DB-API 2 specs
# (each one matches any of the PostgreSQL type names it contains):

STRING = pgdbType(('char', 'bpchar', 'name', 'text', 'varchar'))
BINARY = pgdbType(('bytea',))
NUMBER = pgdbType(('int2', 'int4', 'serial', 'int8',
    'float4', 'float8', 'numeric', 'money'))
DATETIME = pgdbType(('date', 'time', 'timetz', 'timestamp', 'timestamptz',
    'datetime', 'abstime', 'interval', 'tinterval', 'timespan', 'reltime'))
ROWID = pgdbType(('oid', 'oid8'))


# Additional, more specific type objects:

BOOL = pgdbType(('bool',))
SMALLINT = pgdbType(('int2',))
INTEGER = pgdbType(('int2', 'int4', 'int8', 'serial'))
LONG = pgdbType(('int8',))
FLOAT = pgdbType(('float4', 'float8'))
NUMERIC = pgdbType(('numeric',))
MONEY = pgdbType(('money',))
DATE = pgdbType(('date',))
TIME = pgdbType(('time', 'timetz'))
TIMESTAMP = pgdbType(('timestamp', 'timestamptz', 'datetime', 'abstime'))
INTERVAL = pgdbType(('interval', 'tinterval', 'timespan', 'reltime'))
659
660
661# Mandatory type helpers defined by DB-API 2 specs:
662
def Date(year, month, day):
    """Construct an object holding a date value."""
    return date(year=year, month=month, day=day)


def Time(hour, minute=0, second=0, microsecond=0):
    """Construct an object holding a time value."""
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond)


def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Construct an object holding a time stamp value."""
    return datetime(year=year, month=month, day=day, hour=hour,
                    minute=minute, second=second, microsecond=microsecond)


def DateFromTicks(ticks):
    """Construct a date value from seconds since the epoch (local time)."""
    tm = localtime(ticks)
    return Date(tm.tm_year, tm.tm_mon, tm.tm_mday)


def TimeFromTicks(ticks):
    """Construct a time value from seconds since the epoch (local time).

    Fractions of a second are dropped.
    """
    tm = localtime(ticks)
    return Time(tm.tm_hour, tm.tm_min, tm.tm_sec)


def TimestampFromTicks(ticks):
    """Construct a time stamp from seconds since the epoch (local time).

    Fractions of a second are dropped.
    """
    tm = localtime(ticks)
    return Timestamp(tm.tm_year, tm.tm_mon, tm.tm_mday,
                     tm.tm_hour, tm.tm_min, tm.tm_sec)
691
692
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    When passed as a query parameter, a Binary value is escaped with
    the connection's escape_bytea() rather than treated as plain text.
    """
695
696
697# If run as script, print some information:
698
if __name__ == '__main__':
    # show the PyGreSQL version followed by the module help text
    print('PyGreSQL version', version)
    print()
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.