source: trunk/module/pgdb.py @ 683

Last change on this file since 683 was 683, checked in by cito, 4 years ago

Return rows as named tuples in pgdb

By default, we now return result rows as named tuples in pgdb.

Note that named tuples can be accessed like normal lists and tuples,
and easily converted to these. They can also be easily converted to
(ordered) dictionaries by calling row._asdict(). Therefore the need
for alternative Cursor types with different row types has been greatly
reduced, so I have simplified the implementation in the last revision
by removing the added Cursor classes and cursor() methods again,
leaving only the old row_factory method for customizing the returned
row types. I complemented this with a new build_row_factory method,
because different named tuple classes must be created for different
result sets, so a static row_factory method is not so appropriate.

Tests and documentation for these changes are included.

  • Property svn:keywords set to Id
File size: 22.2 KB
Line 
1#! /usr/bin/python
2#
3# pgdb.py
4#
5# Written by D'Arcy J.M. Cain
6#
7# $Id: pgdb.py 683 2016-01-01 23:11:01Z cito $
8#
9
10"""pgdb - DB-API 2.0 compliant module for PygreSQL.
11
12(c) 1999, Pascal Andre <andre@via.ecp.fr>.
13See package documentation for further information on copyright.
14
15Inline documentation is sparse.
16See DB-API 2.0 specification for usage information:
17http://www.python.org/peps/pep-0249.html
18
19Basic usage:
20
    connection = pgdb.connect(connect_string)  # open a connection
    # connect_string = 'host:database:user:password:opt'
    # All parts are optional. You may also pass host, database,
    # user and password as keyword arguments. To pass a port,
    # pass it in the host keyword parameter:
    connection = pgdb.connect(host='localhost:5432')
27
28    cursor = connection.cursor() # open a cursor
29
    cursor.execute(query[, params])
    # Execute a query, binding params (a dictionary or a sequence)
    # if they are passed. The binding syntax is the same as for the
    # % operator; the parameter values are quoted automatically.

    cursor.executemany(query, list of params)
    # Execute a query many times, binding each params dictionary
    # or sequence from the list.
38
    cursor.fetchone()  # fetch one row, (value, value, ...)
    # By default, rows are returned as named tuples.

    cursor.fetchall()  # fetch all rows, [(value, value, ...), ...]

    cursor.fetchmany([size])
    # returns up to size (default: cursor.arraysize) rows,
    # [(value, value, ...), ...] from the result set.
    # Default cursor.arraysize is 1.
47
48    cursor.description # returns information about the columns
49    #   [(column_name, type_name, display_size,
50    #           internal_size, precision, scale, null_ok), ...]
51    # Note that display_size, precision, scale and null_ok
52    # are not implemented.
53
54    cursor.rowcount # number of rows available in the result set
55    # Available after a call to execute.
56
57    connection.commit() # commit transaction
58
59    connection.rollback() # or rollback transaction
60
61    cursor.close() # close the cursor
62
63    connection.close() # close the connection
64
65"""
66
67from __future__ import print_function
68
69from _pg import *
70
71from datetime import date, time, datetime, timedelta
72from time import localtime
73from decimal import Decimal
74from math import isnan, isinf
75from collections import namedtuple
76
# Compatibility names, so that the module can use the Python 2 builtins
# long, unicode and basestring under Python 3 as well.

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    basestring
except NameError:  # Python >= 3.0
    # both text and byte strings count as strings here; Cursor._quote
    # and Type test values with isinstance(..., basestring)
    basestring = (str, bytes)

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    try:
        # presumably the separately installed "ordereddict" backport
        from ordereddict import OrderedDict
    except Exception:
        # last resort: a placeholder that only fails when it is called
        def OrderedDict(*args):
            raise NotSupportedError('OrderedDict is not supported')

# Hand the initial type for numeric values to the _pg extension
# (set_decimal comes from the "from _pg import *" star import);
# decimal_type() repeats this call whenever the type is changed.
set_decimal(Decimal)
102
103
### Module Constants

# DB-API 2.0 module globals:
#   apilevel     -- the DB-API level this module complies with
#   threadsafety -- 1: threads may share the module, but not connections
#   paramstyle   -- Python extended format codes, e.g. %(name)s
apilevel, threadsafety, paramstyle = '2.0', 1, 'pyformat'

# Connection shortcut methods (execute/executemany called directly on
# the connection) are disabled by default: they were excluded from
# DB-API 2 and are not recommended by the DB SIG.
shortcutmethods = 0
120
121
122### Internal Types Handling
123
def decimal_type(decimal_type=None):
    """Return the type used for numeric values, optionally replacing it.

    When a type (e.g. float) is passed, it becomes the new module-level
    Decimal, the typecast for the 'numeric' column type and, through
    set_decimal(), the type used by the _pg extension.
    The type in effect afterwards is returned.
    """
    global Decimal
    if decimal_type is None:
        return Decimal
    _cast['numeric'] = decimal_type
    Decimal = decimal_type
    set_decimal(Decimal)
    return Decimal
131
132
133def _cast_bool(value):
134    return value[:1] in ('t', 'T')
135
136
137def _cast_money(value):
138    return Decimal(''.join(filter(
139        lambda v: v in '0123456789.-', value)))
140
141
def _cast_bytea(value):
    """Convert escaped bytea text from the server into a binary string.

    The actual decoding is delegated to unescape_bytea() of _pg.
    """
    unescaped = unescape_bytea(value)
    return unescaped
144
145
146def _cast_float(value):
147    try:
148        return float(value)
149    except ValueError:
150        if value == 'NaN':
151            return nan
152        elif value == 'Infinity':
153            return inf
154        elif value == '-Infinity':
155            return -inf
156        raise
157
158
# Typecast functions keyed by PostgreSQL type name. Values of types not
# listed here are returned unchanged by TypeCache.typecast; the entry
# for 'numeric' is replaced by decimal_type().
_cast = dict(
    bool=_cast_bool, bytea=_cast_bytea,
    int2=int, int4=int, serial=int,
    int8=long, oid=long, oid8=long,
    float4=_cast_float, float8=_cast_float,
    numeric=Decimal, money=_cast_money)
164
165
def _db_error(msg, cls=DatabaseError):
    """Build (without raising) a cls exception with sqlstate set to None.

    Presumably used for errors detected on the client side, for which
    no SQLSTATE code from the server is available.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
171
172
def _op_error(msg):
    """Build (without raising) an OperationalError with sqlstate None."""
    return _db_error(msg, cls=OperationalError)
176
177
class TypeCache(dict):
    """Cache for database types.

    Maps type OIDs to partial column descriptions of the form
    (typname, None, typlen, None, None, None); the caller prepends the
    column name to get a full cursor description entry.
    """

    def __init__(self, cnx):
        """Initialize an empty type cache querying through cnx."""
        super(TypeCache, self).__init__()
        self._src = cnx.source()

    @staticmethod
    def typecast(typ, value):
        """Cast value to database type.

        NULL values (None) and values of types without a registered
        typecast function are returned unchanged.
        """
        if value is not None:
            cast = _cast.get(typ)
            if cast is not None:
                return cast(value)
        return value

    def getdescr(self, oid):
        """Get the (cached) partial description for the type oid."""
        try:
            descr = self[oid]
        except KeyError:
            # not cached yet: look the type up in the system catalog
            self._src.execute(
                "SELECT typname, typlen FROM pg_type WHERE oid=%s" % oid)
            row = self._src.fetch(1)[0]
            # the column name is left out here and must be prepended
            # by the caller
            descr = (row[0], None, int(row[1]), None, None, None)
            self[oid] = descr
        return descr
213
214
215class _quotedict(dict):
216    """Dictionary with auto quoting of its items.
217
218    The quote attribute must be set to the desired quote function.
219
220    """
221
222    def __getitem__(self, key):
223        return self.quote(super(_quotedict, self).__getitem__(key))
224
225
226### Cursor Object
227
228class Cursor(object):
229    """Cursor object."""
230
231    def __init__(self, dbcnx):
232        """Create a cursor object for the database connection."""
233        self.connection = self._dbcnx = dbcnx
234        self._cnx = dbcnx._cnx
235        self._type_cache = dbcnx._type_cache
236        self._src = self._cnx.source()
237        # the official attribute for describing the result columns
238        self.description = None
239        # unofficial attributes for convenience and performance
240        self.colnames = self.coltypes = None
241        if self.row_factory is Cursor.row_factory:
242            # the row factory needs to be determined dynamically
243            self.row_factory = None
244        else:
245            self.build_row_factory = None
246        self.rowcount = -1
247        self.arraysize = 1
248        self.lastrowid = None
249
250    def __iter__(self):
251        """Make cursors compatible to the iteration protocol."""
252        return self
253
254    def __enter__(self):
255        """Enter the runtime context for the cursor object."""
256        return self
257
258    def __exit__(self, et, ev, tb):
259        """Exit the runtime context for the cursor object."""
260        self.close()
261
262    def _quote(self, val):
263        """Quote value depending on its type."""
264        if isinstance(val, (datetime, date, time, timedelta)):
265            val = str(val)
266        if isinstance(val, basestring):
267            if isinstance(val, Binary):
268                val = self._cnx.escape_bytea(val)
269                if bytes is not str:  # Python >= 3.0
270                    val = val.decode('ascii')
271            else:
272                val = self._cnx.escape_string(val)
273            val = "'%s'" % val
274        elif isinstance(val, (int, long)):
275            pass
276        elif isinstance(val, float):
277            if isinf(val):
278                return "'-Infinity'" if val < 0 else "'Infinity'"
279            elif isnan(val):
280                return "'NaN'"
281        elif val is None:
282            val = 'NULL'
283        elif isinstance(val, (list, tuple)):
284            val = '(%s)' % ','.join(map(lambda v: str(self._quote(v)), val))
285        elif Decimal is not float and isinstance(val, Decimal):
286            pass
287        elif hasattr(val, '__pg_repr__'):
288            val = val.__pg_repr__()
289        else:
290            raise InterfaceError(
291                'do not know how to handle type %s' % type(val))
292        return val
293
294    def _quoteparams(self, string, params):
295        """Quote parameters.
296
297        This function works for both mappings and sequences.
298
299        """
300        if isinstance(params, dict):
301            params = _quotedict(params)
302            params.quote = self._quote
303        else:
304            params = tuple(map(self._quote, params))
305        return string % params
306
307    def close(self):
308        """Close the cursor object."""
309        self._src.close()
310        self.description = None
311        self.colnames = self.coltypes = None
312        self.rowcount = -1
313        self.lastrowid = None
314
315    def execute(self, operation, params=None):
316        """Prepare and execute a database operation (query or command)."""
317
318        # The parameters may also be specified as list of
319        # tuples to e.g. insert multiple rows in a single
320        # operation, but this kind of usage is deprecated:
321        if (params and isinstance(params, list)
322                and isinstance(params[0], tuple)):
323            return self.executemany(operation, params)
324        else:
325            # not a list of tuples
326            return self.executemany(operation, [params])
327
328    def executemany(self, operation, param_seq):
329        """Prepare operation and execute it against a parameter sequence."""
330        if not param_seq:
331            # don't do anything without parameters
332            return
333        self.description = None
334        self.colnames = self.coltypes = None
335        self.rowcount = -1
336        # first try to execute all queries
337        rowcount = 0
338        sql = "BEGIN"
339        try:
340            if not self._dbcnx._tnx:
341                try:
342                    self._cnx.source().execute(sql)
343                except DatabaseError:
344                    raise
345                except Exception:
346                    raise _op_error("can't start transaction")
347                self._dbcnx._tnx = True
348            for params in param_seq:
349                if params:
350                    sql = self._quoteparams(operation, params)
351                else:
352                    sql = operation
353                rows = self._src.execute(sql)
354                if rows:  # true if not DML
355                    rowcount += rows
356                else:
357                    self.rowcount = -1
358        except DatabaseError:
359            raise
360        except Error as err:
361            raise _db_error("error in '%s': '%s' " % (sql, err))
362        except Exception as err:
363            raise _op_error("internal error in '%s': %s" % (sql, err))
364        # then initialize result raw count and description
365        if self._src.resulttype == RESULT_DQL:
366            self.rowcount = self._src.ntuples
367            getdescr = self._type_cache.getdescr
368            description = [CursorDescription(
369                info[1], *getdescr(info[2])) for info in self._src.listinfo()]
370            self.colnames = [info[0] for info in description]
371            self.coltypes = [info[1] for info in description]
372            self.description = description
373            self.lastrowid = None
374            if self.build_row_factory:
375                self.row_factory = self.build_row_factory()
376        else:
377            self.rowcount = rowcount
378            self.lastrowid = self._src.oidstatus()
379        # return the cursor object, so you can write statements such as
380        # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)"
381        return self
382
383    def fetchone(self):
384        """Fetch the next row of a query result set."""
385        res = self.fetchmany(1, False)
386        try:
387            return res[0]
388        except IndexError:
389            return None
390
391    def fetchall(self):
392        """Fetch all (remaining) rows of a query result."""
393        return self.fetchmany(-1, False)
394
395    def fetchmany(self, size=None, keep=False):
396        """Fetch the next set of rows of a query result.
397
398        The number of rows to fetch per call is specified by the
399        size parameter. If it is not given, the cursor's arraysize
400        determines the number of rows to be fetched. If you set
401        the keep parameter to true, this is kept as new arraysize.
402
403        """
404        if size is None:
405            size = self.arraysize
406        if keep:
407            self.arraysize = size
408        try:
409            result = self._src.fetch(size)
410        except DatabaseError:
411            raise
412        except Error as err:
413            raise _db_error(str(err))
414        typecast = self._type_cache.typecast
415        return [self.row_factory([typecast(typ, value)
416            for typ, value in zip(self.coltypes, row)]) for row in result]
417
418    def __next__(self):
419        """Return the next row (support for the iteration protocol)."""
420        res = self.fetchone()
421        if res is None:
422            raise StopIteration
423        return res
424
425    # Note that since Python 2.6 the iterator protocol uses __next()__
426    # instead of next(), we keep it only for backward compatibility of pgdb.
427    next = __next__
428
429    @staticmethod
430    def nextset():
431        """Not supported."""
432        raise NotSupportedError("nextset() is not supported")
433
434    @staticmethod
435    def setinputsizes(sizes):
436        """Not supported."""
437        pass  # unsupported, but silently passed
438
439    @staticmethod
440    def setoutputsize(size, column=0):
441        """Not supported."""
442        pass  # unsupported, but silently passed
443
444    @staticmethod
445    def row_factory(row):
446        """Process rows before they are returned.
447
448        You can overwrite this statically with a custom row factory, or
449        you can build a row factory dynamically with build_row_factory().
450
451        For example, you can create a Cursor class that returns rows as
452        Python dictionaries like this:
453
454            class DictCursor(pgdb.Cursor):
455
456                def row_factory(self, row):
457                    return {desc[0]: value
458                        for desc, value in zip(self.description, row)}
459
460            cur = DictCursor(con)  # get one DictCursor instance or
461            con.cursor_type = DictCursor  # always use DictCursor instances
462
463        """
464        raise NotImplementedError
465
466    def build_row_factory(self):
467        """Build a row factory based on the current description.
468
469        This implementation builds a row factory for creating named tuples.
470        You can overwrite this method if you want to dynamically create
471        different row factories whenever the column description changes.
472
473        """
474        colnames = self.colnames
475        if colnames:
476            try:
477                try:
478                    return namedtuple('Row', colnames, rename=True)._make
479                except TypeError:  # Python 2.6 and 3.0 do not support rename
480                    colnames = [v if v.isalnum() else 'column_%d' % n
481                             for n, v in enumerate(colnames)]
482                    return namedtuple('Row', colnames)._make
483            except ValueError:  # there is still a problem with the field names
484                colnames = ['column_%d' % n for n in range(len(colnames))]
485                return namedtuple('Row', colnames)._make
486
487
# Named tuple type for the seven items of every cursor.description
# entry required by DB-API 2.
CursorDescription = namedtuple(
    'CursorDescription',
    'name type_code display_size internal_size precision scale null_ok')
491
492
493### Connection Objects
494
class Connection(object):
    """Connection object."""

    # expose the exceptions as attributes on the connection object
    Error = Error
    Warning = Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def __init__(self, cnx):
        """Create a database connection object.

        Raises OperationalError if no query source can be obtained
        from the given low-level connection.
        """
        self._cnx = cnx  # connection
        self._tnx = False  # transaction state
        # Validate the connection before building the type cache:
        # TypeCache() itself calls cnx.source(), and a failure there
        # must be reported as OperationalError like any invalid connection.
        try:
            self._cnx.source()
        except Exception:
            raise _op_error("invalid connection")
        self._type_cache = TypeCache(cnx)
        self.cursor_type = Cursor

    def __enter__(self):
        """Enter the runtime context for the connection object.

        The runtime context can be used for running transactions.

        """
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context for the connection object.

        This does not close the connection, but it ends a transaction:
        it is committed on normal exit and rolled back on an exception.

        """
        if et is None and ev is None and tb is None:
            self.commit()
        else:
            self.rollback()

    def close(self):
        """Close the connection object.

        A pending transaction is rolled back first (database errors
        during the rollback are ignored). Raises OperationalError if
        the connection has already been closed.
        """
        if self._cnx:
            if self._tnx:
                try:
                    self.rollback()
                except DatabaseError:
                    pass
            self._cnx.close()
            self._cnx = None
        else:
            raise _op_error("connection has been closed")

    def commit(self):
        """Commit any pending transaction to the database."""
        if self._cnx:
            if self._tnx:
                self._tnx = False
                try:
                    self._cnx.source().execute("COMMIT")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't commit")
        else:
            raise _op_error("connection has been closed")

    def rollback(self):
        """Roll back to the start of any pending transaction."""
        if self._cnx:
            if self._tnx:
                self._tnx = False
                try:
                    self._cnx.source().execute("ROLLBACK")
                except DatabaseError:
                    raise
                except Exception:
                    raise _op_error("can't rollback")
        else:
            raise _op_error("connection has been closed")

    def cursor(self):
        """Return a new cursor object (of type cursor_type) for the connection."""
        if self._cnx:
            try:
                return self.cursor_type(self)
            except Exception:
                raise _op_error("invalid connection")
        else:
            raise _op_error("connection has been closed")

    if shortcutmethods:  # otherwise do not implement and document this

        def execute(self, operation, params=None):
            """Shortcut method to run an operation on an implicit cursor."""
            cursor = self.cursor()
            cursor.execute(operation, params)
            return cursor

        def executemany(self, operation, param_seq):
            """Shortcut method to run an operation against a sequence."""
            cursor = self.cursor()
            cursor.executemany(operation, param_seq)
            return cursor
604
605
606### Module Interface
607
_connect_ = connect


def connect(dsn=None, user=None, password=None, host=None, database=None):
    """Connects to a database.

    dsn has the form 'host:database:user:password:opt'; every part is
    optional and missing parts default to empty strings. The user,
    password and database keywords override the corresponding DSN parts.
    host may be given as 'host' or 'host:port' and overrides the DSN host.
    An empty host or user is passed on to the _pg connect as None.
    """
    # split the DSN into its (up to) five positional parts
    fields = ["", "", "", "", ""]
    try:
        for position, part in enumerate(dsn.split(":")[:5]):
            fields[position] = part
    except (AttributeError, IndexError, TypeError):
        pass  # no DSN given, or not a string
    dbhost, dbbase, dbuser, dbpasswd, dbopt = fields
    dbport = -1

    # keyword arguments take precedence over the DSN
    dbuser = dbuser if user is None else user
    dbpasswd = dbpasswd if password is None else password
    dbbase = dbbase if database is None else database
    if host is not None:
        try:
            hostparts = host.split(":")
            dbhost = hostparts[0]
            dbport = int(hostparts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            pass  # no (valid) port given: keep what has been set

    # empty host is localhost
    dbhost = None if dbhost == "" else dbhost
    dbuser = None if dbuser == "" else dbuser

    # open the connection
    cnx = _connect_(dbbase, dbhost, dbport, dbopt, dbuser, dbpasswd)
    return Connection(cnx)
655
656
657### Types Handling
658
class Type(frozenset):
    """Type class for a couple of PostgreSQL data types.

    PostgreSQL is object-oriented: types are dynamic.
    We must thus use type names as internal type codes.

    A Type is a frozenset of type names that compares equal to every
    name it contains (and unequal to any other string), so that e.g.
    cursor.description[0].type_code == STRING works.

    """

    def __new__(cls, values):
        """Create a Type from an iterable or space-separated string of names."""
        if isinstance(values, basestring):
            values = values.split()
        return super(Type, cls).__new__(cls, values)

    def __eq__(self, other):
        if isinstance(other, basestring):
            return other in self
        else:
            return super(Type, self).__eq__(other)

    def __ne__(self, other):
        if isinstance(other, basestring):
            return other not in self
        else:
            return super(Type, self).__ne__(other)

    # Overriding __eq__ implicitly sets __hash__ to None on Python 3,
    # which made Type objects unhashable; restore the frozenset hash.
    # Hashes are consistent with equality only among Type and frozenset
    # objects, not with the single type name strings a Type equals.
    __hash__ = frozenset.__hash__
683
684
# Mandatory type objects defined by DB-API 2 specs:
# Each one equals every PostgreSQL type name it lists, so they can be
# compared with the type_code (type name) of cursor.description entries.

STRING = Type('char bpchar name text varchar')
BINARY = Type('bytea')
NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money')
DATETIME = Type('date time timetz timestamp timestamptz datetime abstime'
    ' interval tinterval timespan reltime')
ROWID = Type('oid oid8')


# Additional type objects (more specific):

BOOL = Type('bool')
SMALLINT = Type('int2')
INTEGER = Type('int2 int4 int8 serial')
LONG = Type('int8')
FLOAT = Type('float4 float8')
NUMERIC = Type('numeric')
MONEY = Type('money')
DATE = Type('date')
TIME = Type('time timetz')
TIMESTAMP = Type('timestamp timestamptz datetime abstime')
INTERVAL = Type('interval tinterval timespan reltime')
708
709
# Mandatory type helpers defined by DB-API 2 specs:

def Date(year, month, day):
    """Construct an object holding a date value."""
    return date(year=year, month=month, day=day)


def Time(hour, minute=0, second=0, microsecond=0):
    """Construct an object holding a time value."""
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond)


def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0):
    """Construct an object holding a time stamp value."""
    return datetime(year=year, month=month, day=day, hour=hour,
                    minute=minute, second=second, microsecond=microsecond)


def DateFromTicks(ticks):
    """Construct an object holding a date value from the given ticks value.

    ticks are seconds since the epoch, converted using local time.
    """
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)


def TimeFromTicks(ticks):
    """Construct an object holding a time value from the given ticks value.

    Fractional seconds are dropped by the conversion with localtime().
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)


def TimestampFromTicks(ticks):
    """Construct an object holding a time stamp from the given ticks value.

    Fractional seconds are dropped by the conversion with localtime().
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
740
741
class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    When used as a query parameter, the cursor escapes Binary values
    with escape_bytea instead of escape_string.
    """
744
745
# When run as a script, show the PyGreSQL version and the usage notes:

if __name__ == '__main__':
    print('PyGreSQL version', version)
    print()
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.