source: trunk/module/pgdb.py @ 681

Last change on this file since 681 was 681, checked in by cito, 4 years ago

Let cursor.description return named tuples

  • Property svn:keywords set to Id
File size: 20.1 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 681 2015-12-30 22:39:53Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
21    pgdb.connect(connect_string) # open a connection
22    # connect_string = 'host:database:user:password:opt'
    # All parts are optional. You may also pass host, database,
    # user and password as keyword arguments. To pass a port,
    # append it to the host keyword parameter:
26    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
30    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is the same as for the
    # % operator (pyformat), and the values are quoted automatically.
34
35    cursor.executemany(query, list of params)
36    # Execute a query many times, binding each param dictionary
37    # from the list.
38
39    cursor.fetchone() # fetch one row, [value, value, ...]
40
41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
42
43    cursor.fetchmany([size])
44    # returns size or cursor.arraysize number of rows,
45    # [[value, value, ...], ...] from result set.
46    # Default cursor.arraysize is 1.
47
    cursor.description # returns information about the columns
    #   as a list of named tuples
    #   [(name, type_code, display_size,
    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64
65"""
66
67from __future__ import print_function
68
69from _pg import *
70
71from datetime import date, time, datetime, timedelta
72from time import localtime
73from decimal import Decimal
74from math import isnan, isinf
75from collections import namedtuple
76
# Python 3 no longer has the Python 2 names long, unicode and basestring,
# which this module uses for type checks and typecasts; provide them.
if bytes is not str:  # Python >= 3.0
    long = int
    unicode = str
    basestring = (str, bytes)
91
# Hand the Decimal type to the _pg module (presumably so that it is used
# for numeric values there, too); decimal_type() can replace it later.
set_decimal(Decimal)
93
94
95### Module Constants
96
# the module complies with DB API 2.0 (PEP 249)
apilevel = '2.0'

# threads may share the module, but must not share connections
threadsafety = 1

# parameters are passed using Python extended format codes, e.g. %(name)s
paramstyle = 'pyformat'

# The shortcut methods Connection.execute() and Connection.executemany()
# are not provided by default, since they have been excluded from
# DB API 2 and are not recommended by the DB SIG.

shortcutmethods = 0
111
112
113### Internal Types Handling
114
def decimal_type(decimal_type=None):
    """Get or set the global type used for decimal (numeric) values.

    Called without an argument, return the current type.  Called with a
    type, install it as the typecast for 'numeric' columns, pass it on
    to set_decimal() and return it.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    set_decimal(decimal_type)
    return Decimal
122
123
124def _cast_bool(value):
125    return value[:1] in ('t', 'T')
126
127
128def _cast_money(value):
129    return Decimal(''.join(filter(
130        lambda v: v in '0123456789.-', value)))
131
132
def _cast_bytea(value):
    """Convert an escaped bytea text value using unescape_bytea()."""
    unescaped = unescape_bytea(value)
    return unescaped
135
136
137def _cast_float(value):
138    try:
139        return float(value)
140    except ValueError:
141        if value == 'NaN':
142            return nan
143        elif value == 'Infinity':
144            return inf
145        elif value == '-Infinity':
146            return -inf
147        raise
148
149
# Typecast functions keyed by PostgreSQL type name; values of types
# not listed here are returned unchanged by TypeCache.typecast().
_cast = dict(
    bool=_cast_bool, bytea=_cast_bytea,
    int2=int, int4=int, serial=int,
    int8=long, oid=long, oid8=long,
    float4=_cast_float, float8=_cast_float,
    numeric=Decimal, money=_cast_money)
155
156
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an exception of class cls.

    The exception gets an sqlstate attribute set to None, since
    the error did not originate from the database server.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
162
163
def _op_error(msg):
    """Create (but do not raise) an OperationalError with empty sqlstate."""
    return _db_error(msg, cls=OperationalError)
167
168
class TypeCache(dict):
    """Per-connection cache mapping type OIDs to column type descriptions."""

    def __init__(self, cnx):
        """Create an empty cache that queries pg_type through cnx."""
        super(TypeCache, self).__init__()
        self._src = cnx.source()

    @staticmethod
    def typecast(typ, value):
        """Convert value using the typecast registered for type name typ.

        NULL values (None) and values of types without a registered
        typecast are returned unchanged.
        """
        if value is None:
            return None
        converter = _cast.get(typ)
        return value if converter is None else converter(value)

    def getdescr(self, oid):
        """Return the description tuple for the type with the given oid.

        The tuple is (type name, display_size, internal_size, precision,
        scale, null_ok), where only the type name and the internal size
        (typlen) are filled in and the rest is None.  The column name is
        not part of it and has to be prepended by the caller.
        Results are cached in this dictionary.
        """
        if oid in self:
            return self[oid]
        self._src.execute(
            "SELECT typname, typlen "
            "FROM pg_type WHERE oid=%s" % oid)
        row = self._src.fetch(1)[0]
        descr = (row[0], None, int(row[1]), None, None, None)
        self[oid] = descr
        return descr
204
205
206class _quotedict(dict):
207    """Dictionary with auto quoting of its items.
208
209    The quote attribute must be set to the desired quote function.
210
211    """
212
213    def __getitem__(self, key):
214        return self.quote(super(_quotedict, self).__getitem__(key))
215
216
217### Cursor Object
218
class Cursor(object):
    """Cursor object."""

    def __init__(self, dbcnx):
        """Create a cursor object for the database connection.

        dbcnx is the Connection object; the cursor shares its low-level
        connection and type cache, but gets a source object of its own.
        """
        self.connection = self._dbcnx = dbcnx
        self._cnx = dbcnx._cnx
        self._type_cache = dbcnx._type_cache
        self._src = self._cnx.source()
        self.description = None
        self.rowcount = -1
        self.arraysize = 1
        self.lastrowid = None

    def __iter__(self):
        """Make cursors compatible to the iteration protocol."""
        return self

    def __enter__(self):
        """Enter the runtime context for the cursor object."""
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the cursor object (closes it)."""
        self.close()

    def _quote(self, val):
        """Quote value depending on its type.

        Date/time values and strings are escaped and put in quotes,
        finite floats are converted with full precision, infinite and
        NaN floats become quoted PostgreSQL literals, None becomes NULL,
        lists and tuples become a parenthesized list of quoted values,
        and objects with a __pg_repr__ method are converted with it.
        Ints and Decimals are returned unchanged (the % operator formats
        them later).  Raises InterfaceError for any other type.
        """
        if isinstance(val, (datetime, date, time, timedelta)):
            val = str(val)
        if isinstance(val, basestring):
            if isinstance(val, Binary):
                val = self._cnx.escape_bytea(val)
                if bytes is not str:  # Python >= 3.0
                    val = val.decode('ascii')
            else:
                val = self._cnx.escape_string(val)
            val = "'%s'" % val
        elif isinstance(val, (int, long)):
            pass
        elif isinstance(val, float):
            if isinf(val):
                return "'-Infinity'" if val < 0 else "'Infinity'"
            elif isnan(val):
                return "'NaN'"
            # Use repr(): in Python 2, str() (and thus the % operator)
            # rounds floats to 12 significant digits, losing precision.
            # In Python 3 both give the same shortest exact representation.
            val = repr(val)
        elif val is None:
            val = 'NULL'
        elif isinstance(val, (list, tuple)):
            val = '(%s)' % ','.join(map(lambda v: str(self._quote(v)), val))
        elif Decimal is not float and isinstance(val, Decimal):
            pass
        elif hasattr(val, '__pg_repr__'):
            val = val.__pg_repr__()
        else:
            raise InterfaceError(
                'do not know how to handle type %s' % type(val))
        return val

    def _quoteparams(self, string, params):
        """Quote parameters and substitute them into the string.

        This function works for both mappings (for %(name)s placeholders)
        and sequences (for %s placeholders).

        """
        if isinstance(params, dict):
            params = _quotedict(params)
            params.quote = self._quote
        else:
            params = tuple(map(self._quote, params))
        return string % params

    @staticmethod
    def row_factory(row):
        """Process rows before they are returned.

        You can overwrite this with a custom row factory,
        e.g. a dict factory:

            class DictCursor(pgdb.Cursor):

                def row_factory(self, row):
                    return {desc[0]:value
                        for desc, value in zip(self.description, row)}

            cur = DictCursor(con)

        """
        return row

    def close(self):
        """Close the cursor object and reset its result attributes."""
        self._src.close()
        self.description = None
        self.rowcount = -1
        self.lastrowid = None

    def execute(self, operation, params=None):
        """Prepare and execute a database operation (query or command).

        Returns the cursor itself (see executemany()).
        """

        # The parameters may also be specified as list of
        # tuples to e.g. insert multiple rows in a single
        # operation, but this kind of usage is deprecated:
        if (params and isinstance(params, list)
                and isinstance(params[0], tuple)):
            return self.executemany(operation, params)
        else:
            # not a list of tuples
            return self.executemany(operation, [params])

    def executemany(self, operation, param_seq):
        """Prepare operation and execute it against a parameter sequence.

        A transaction is started first if none is open on the connection.
        Each item of param_seq (a dict, a sequence, or a false value for
        no parameters) is quoted and substituted into the operation.
        Afterwards, for a query, rowcount and description refer to its
        result; otherwise rowcount is the sum of the row counts returned
        by the executed statements.  Returns the cursor itself, or None
        without doing anything if param_seq is empty.
        """
        if not param_seq:
            # don't do anything without parameters
            return
        self.description = None
        self.rowcount = -1
        # first try to execute all queries
        totrows = 0
        sql = "BEGIN"
        try:
            if not self._dbcnx._tnx:
                try:
                    self._cnx.source().execute(sql)
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't start transaction")
                self._dbcnx._tnx = True
            for params in param_seq:
                if params:
                    sql = self._quoteparams(operation, params)
                else:
                    sql = operation
                rows = self._src.execute(sql)
                if rows:  # true if not DML
                    totrows += rows
                else:
                    self.rowcount = -1
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error("error in '%s': '%s' " % (sql, err))
        except Exception as err:
            raise _op_error("internal error in '%s': %s" % (sql, err))
        # then initialize result row count and description
        if self._src.resulttype == RESULT_DQL:
            self.rowcount = self._src.ntuples
            getdescr = self._type_cache.getdescr
            coltypes = self._src.listinfo()
            self.description = [CursorDescription(
                typ[1], *getdescr(typ[2])) for typ in coltypes]
            self.lastrowid = None
        else:
            self.rowcount = totrows
            self.description = None
            self.lastrowid = self._src.oidstatus()
        # return the cursor object, so you can write statements such as
        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
        return self

    def fetchone(self):
        """Fetch the next row of a query result set.

        Returns None when no more rows are available.
        """
        res = self.fetchmany(1, False)
        try:
            return res[0]
        except IndexError:
            return None

    def fetchall(self):
        """Fetch all (remaining) rows of a query result."""
        return self.fetchmany(-1, False)

    def fetchmany(self, size=None, keep=False):
        """Fetch the next set of rows of a query result.

        The number of rows to fetch per call is specified by the
        size parameter. If it is not given, the cursor's arraysize
        determines the number of rows to be fetched. If you set
        the keep parameter to true, this is kept as new arraysize.

        Each value is converted with the typecast for its column type
        and each row is passed through row_factory().

        """
        if size is None:
            size = self.arraysize
        if keep:
            self.arraysize = size
        try:
            result = self._src.fetch(size)
        except DatabaseError:
            raise
        except Error as err:
            raise _db_error(str(err))
        row_factory = self.row_factory
        typecast = self._type_cache.typecast
        coltypes = [desc[1] for desc in self.description]
        return [row_factory([typecast(*args)
            for args in zip(coltypes, row)]) for row in result]

    def __next__(self):
        """Return the next row (support for the iteration protocol)."""
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    # Note that since Python 3.0 the iterator protocol uses __next__()
    # instead of next(); we keep next() for Python 2 and for backward
    # compatibility of pgdb.
    next = __next__

    @staticmethod
    def nextset():
        """Not supported."""
        raise NotSupportedError("nextset() is not supported")

    @staticmethod
    def setinputsizes(sizes):
        """Not supported."""
        pass  # unsupported, but silently passed

    @staticmethod
    def setoutputsize(size, column=0):
        """Not supported."""
        pass  # unsupported, but silently passed
441
442
# Named tuple type for the items of Cursor.description; fields that are
# not provided by TypeCache.getdescr() are None.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
446
447
448### Connection Objects
449
class Connection(object):
    """Connection object."""

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        cnx is the low-level connection (as created by connect()).
        Raises OperationalError if no source object can be obtained
        from it.
        """
        self._cnx = cnx  # connection
        self._tnx = False  # transaction state
        # Check the connection before building the type cache: TypeCache
        # calls cnx.source() itself, so creating it first would let the
        # raw error escape instead of the "invalid connection" error.
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("invalid connection")
        self._type_cache = TypeCache(cnx)

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.

        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is committed if no exception occurred, else rolled back.

        """
        if et is None and ev is None and tb is None:
            self.commit()
        else:
            self.rollback()

    def close(self):
        """Close the connection object.

        An open transaction is rolled back first (database errors
        during the rollback are ignored).  Raises OperationalError
        if the connection has already been closed.
        """
        if self._cnx:
            if self._tnx:
                try:
                    self.rollback()
                except DatabaseError:
                    pass
            self._cnx.close()
            self._cnx = None
        else:
            raise _op_error("connection has been closed")

    def commit(self):
        """Commit any pending transaction to the database.

        Raises OperationalError if the connection has been closed.
        """
        if self._cnx:
            if self._tnx:
                self._tnx = False
                try:
                    self._cnx.source().execute("COMMIT")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't commit")
        else:
            raise _op_error("connection has been closed")

    def rollback(self):
        """Roll back to the start of any pending transaction.

        Raises OperationalError if the connection has been closed.
        """
        if self._cnx:
            if self._tnx:
                self._tnx = False
                try:
                    self._cnx.source().execute("ROLLBACK")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't rollback")
        else:
            raise _op_error("connection has been closed")

    def cursor(self):
        """Return a new cursor object using the connection.

        Raises OperationalError if the connection has been closed
        or the cursor cannot be created.
        """
        if self._cnx:
            try:
                return Cursor(self)
            except Exception:
                raise _op_error("invalid connection")
        else:
            raise _op_error("connection has been closed")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cursor = self.cursor()
            cursor.execute(operation, params)
            return cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cursor = self.cursor()
            cursor.executemany(operation, param_seq)
            return cursor
558
559
560### Module Interface
561
# keep the low-level connect function of the _pg module, since its name
# is reused by the DB-API connect function defined below
_connect_ = connect

def connect(dsn=None,
        user=None, password=None,
        host=None, database=None):
    """Connects to a database.

    The dsn has the form 'host:database:user:password:opt', where all
    parts are optional.  The keyword arguments user, password and
    database override the corresponding dsn parts; host may be given
    as 'host' or 'host:port' and overrides the dsn host.  An empty host
    or user is passed on as None.  Returns a Connection object.
    """
    dbport = -1
    # split the DSN into its five parts, missing parts are empty
    try:
        dsn_parts = dsn.split(":")
    except (AttributeError, TypeError):
        dsn_parts = []
    dbhost, dbbase, dbuser, dbpasswd, dbopt = (dsn_parts + [""] * 5)[:5]

    # explicit keyword arguments take precedence over the DSN
    dbuser = dbuser if user is None else user
    dbpasswd = dbpasswd if password is None else password
    dbbase = dbbase if database is None else database
    if host is not None:
        try:
            host_parts = host.split(":")
            dbhost = host_parts[0]
            dbport = int(host_parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass

    # empty host means localhost, empty user means default user
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    return Connection(
        _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd))
609
610
611### Types Handling
612
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozenset of type names that compares equal to every
    type name string it contains, e.g. Type('int4 int8') == 'int4'.

    """

    def __new__(cls, values):
        # a single string is taken as whitespace-separated type names
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Overriding __eq__ implicitly sets __hash__ to None in Python 3,
    # which would make Type objects unhashable; keep frozenset's hash.
    __hash__ = frozenset.__hash__
637
638
# Mandatory type objects defined by DB-API 2 specs
# (each one compares equal to any of the type names it holds):

STRING = Type(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = Type(['bytea'])
NUMBER = Type(['int2', 'int4', 'serial', 'int8',
               'float4', 'float8', 'numeric', 'money'])
DATETIME = Type(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
                 'datetime', 'abstime', 'interval', 'tinterval',
                 'timespan', 'reltime'])
ROWID = Type(['oid', 'oid8'])


# Additional, more specific type objects:

BOOL = Type(['bool'])
SMALLINT = Type(['int2'])
INTEGER = Type(['int2', 'int4', 'int8', 'serial'])
LONG = Type(['int8'])
FLOAT = Type(['float4', 'float8'])
NUMERIC = Type(['numeric'])
MONEY = Type(['money'])
DATE = Type(['date'])
TIME = Type(['time', 'timetz'])
TIMESTAMP = Type(['timestamp', 'timestamptz', 'datetime', 'abstime'])
INTERVAL = Type(['interval', 'tinterval', 'timespan', 'reltime'])
662
663
664# Mandatory type helpers defined by DB-API 2 specs:
665
def Date(year, month, day):
    """Return an object holding the given date."""
    return date(year=year, month=month, day=day)


def Time(hour, minute=0, second=0, microsecond=0):
    """Return an object holding the given time of day."""
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond)


def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Return an object holding the given date and time."""
    return datetime(year=year, month=month, day=day, hour=hour,
                    minute=minute, second=second, microsecond=microsecond)


def DateFromTicks(ticks):
    """Return a date object for the local date at the given epoch ticks."""
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)


def TimeFromTicks(ticks):
    """Return a time object for the local time at the given epoch ticks.

    Fractions of a second are dropped.
    """
    tm = localtime(ticks)
    return Time(tm.tm_hour, tm.tm_min, tm.tm_sec)


def TimestampFromTicks(ticks):
    """Return a datetime for the local date and time at the given ticks.

    Fractions of a second are dropped.
    """
    tm = localtime(ticks)
    return Timestamp(tm.tm_year, tm.tm_mon, tm.tm_mday,
                     tm.tm_hour, tm.tm_min, tm.tm_sec)


class Binary(bytes):
    """Bytes subclass marking a value to be quoted as binary (bytea) data."""
698
699
# When run as a script, print the PyGreSQL version and the module docs:

if __name__ == '__main__':
    print('PyGreSQL version', version, end='\n\n')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.