Changeset 735 for trunk/pg.py


Ignore:
Timestamp:
Jan 13, 2016, 4:49:35 PM (4 years ago)
Author:
cito
Message:

Implement "upsert" method for PostgreSQL 9.5

A new method upsert() has been added to the DB wrapper class that
nicely complements the existing get/insert/update/delete() methods.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/pg.py

    r734 r735  
    514514                raise KeyError('Class %s has no primary key' % cl)
    515515            if len(pkey) > 1:
    516                 pkey = frozenset([k[0] for k in pkey])
     516                pkey = frozenset(k[0] for k in pkey)
    517517            else:
    518518                pkey = pkey[0][0]
     
    625625        if cl.endswith('*'):  # scan descendant tables?
    626626            cl = cl[:-1].rstrip()  # need parent table name
    627         # build qualified class name
    628         # To allow users to work with multiple tables,
    629         # we munge the name of the "oid" key
    630         qoid = _oid_key(cl)
    631627        if not keyname:
    632628            # use the primary key by default
     
    639635        param = partial(self._prepare_param, params=params)
    640636        col = self.escape_identifier
    641         # We want the oid for later updates if that isn't the key
     637        # We want the oid for later updates if that isn't the key.
     638        # To allow users to work with multiple tables, we munge
     639        # the name of the "oid" key by adding the name of the class.
     640        qoid = _oid_key(cl)
    642641        if keyname == 'oid':
    643642            if isinstance(arg, dict):
     
    649648            where = 'oid = %s' % param(arg[qoid], 'int')
    650649        else:
    651             if isinstance(keyname, basestring):
    652                 keyname = (keyname,)
     650            keyname = [keyname] if isinstance(
     651                keyname, basestring) else sorted(keyname)
    653652            if not isinstance(arg, dict):
    654653                if len(keyname) > 1:
     
    656655                arg = dict((k, arg) for k in keyname)
    657656            what = ', '.join(col(k) for k in attnames)
    658             where = ' AND '.join(['%s = %s'
    659                 % (col(k), param(arg[k], attnames[k])) for k in keyname])
     657            where = ' AND '.join('%s = %s' % (
     658                col(k), param(arg[k], attnames[k])) for k in keyname)
    660659        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
    661660            what, self._escape_qualified_name(cl), where)
    662661        self._do_debug(q, params)
    663         res = self.db.query(q, params).dictresult()
     662        q = self.db.query(q, params)
     663        res = q.dictresult()
    664664        if not res:
    665665            raise _db_error('No such record in %s where %s' % (cl, where))
     
    689689
    690690        """
    691         qoid = _oid_key(cl)
     691        if 'oid' in kw:
     692            del kw['oid']
    692693        if d is None:
    693694            d = {}
     
    699700        names, values = [], []
    700701        for n in attnames:
    701             if n != 'oid' and n in d:
     702            if n in d:
    702703                names.append(col(n))
    703704                values.append(param(d[n], attnames[n]))
     
    707708            self._escape_qualified_name(cl), names, values, ret)
    708709        self._do_debug(q, params)
    709         res = self.db.query(q, params)
    710         res = res.dictresult()[0]
    711         for n, value in res.items():
     710        q = self.db.query(q, params)
     711        res = q.dictresult()
     712        if not res:
     713            raise _int_error('insert did not return new values')
     714        for n, value in res[0].items():
    712715            if n == 'oid':
    713                 n = qoid
     716                n = _oid_key(cl)
    714717            elif attnames.get(n) == 'bytea' and value is not None:
    715718                value = self.unescape_bytea(value)
     
    727730
    728731        """
    729         # Update always works on the oid which get returns if available,
     732        # Update always works on the oid which get() returns if available,
    730733        # otherwise use the primary key.  Fail if neither.
    731         # Note that we only accept oid key from named args for safety
     734        # Note that we only accept oid key from named args for safety.
    732735        qoid = _oid_key(cl)
    733736        if 'oid' in kw:
     
    743746        if qoid in d:
    744747            where = 'oid = %s' % param(d[qoid], 'int')
    745             keyname = ()
     748            keyname = []
    746749        else:
    747750            try:
     
    749752            except KeyError:
    750753                raise _prg_error('Class %s has no primary key' % cl)
    751             if isinstance(keyname, basestring):
    752                 keyname = (keyname,)
     754            keyname = [keyname] if isinstance(
     755                keyname, basestring) else sorted(keyname)
    753756            try:
    754                 where = ' AND '.join(['%s = %s'
    755                     % (col(k), param(d[k], attnames[k])) for k in keyname])
     757                where = ' AND '.join('%s = %s' % (
     758                    col(k), param(d[k], attnames[k])) for k in keyname)
    756759            except KeyError:
    757                 raise _prg_error('Update needs primary key or oid.')
     760                raise _prg_error('update needs primary key or oid')
     761        keyname = set(keyname)
     762        keyname.add('oid')
    758763        values = []
    759764        for n in attnames:
     
    767772            self._escape_qualified_name(cl), values, where, ret)
    768773        self._do_debug(q, params)
    769         res = self.db.query(q, params)
    770         res = res.dictresult()[0]
    771         for n, value in res.items():
    772             if n == 'oid':
    773                 n = qoid
    774             elif attnames.get(n) == 'bytea' and value is not None:
    775                 value = self.unescape_bytea(value)
    776             d[n] = value
     774        q = self.db.query(q, params)
     775        res = q.dictresult()
     776        if res:  # may be empty when row does not exist
     777            for n, value in res[0].items():
     778                if n == 'oid':
     779                    n = qoid
     780                elif attnames.get(n) == 'bytea' and value is not None:
     781                    value = self.unescape_bytea(value)
     782                d[n] = value
     783        return d
     784
     785    def upsert(self, cl, d=None, **kw):
     786        """Insert a row into a database table with conflict resolution.
     787
     788        This method inserts a row into a table, but instead of raising a
     789        ProgrammingError exception in case a row with the same primary key
     790        already exists, an update will be executed instead.  This will be
     791        performed as a single atomic operation on the database, so race
     792        conditions can be avoided.
     793
     794        Like the insert method, the first parameter is the name of the
     795        table and the second parameter can be used to pass the values to
     796        be inserted as a dictionary.
     797
     798        Unlike the insert und update statement, keyword parameters are not
     799        used to modify the dictionary, but to specify which columns shall
     800        be updated in case of a conflict, and in which way:
     801
     802        A value of False or None means the column shall not be updated,
     803        a value of True means the column shall be updated with the value
     804        that has been proposed for insertion, i.e. has been passed as value
     805        in the dictionary.  Columns that are not specified by keywords but
     806        appear as keys in the dictionary are also updated like in the case
     807        keywords had been passed with the value True.
     808
     809        So if in the case of a conflict you want to update every column that
     810        has been passed in the dictionary d , you would call upsert(cl, d).
     811        If you don't want to do anything in case of a conflict, i.e. leave
     812        the existing row as it is, call upsert(cl, d, **dict.fromkeys(d)).
     813
     814        If you need more fine-grained control of what gets updated, you can
     815        also pass strings in the keyword parameters.  These strings will
     816        be used as SQL expressions for the update columns.  In these
     817        expressions you can refer to the value that already exists in
     818        the table by prefixing the column name with "included.", and to
     819        the value that has been proposed for insertion by prefixing the
     820        column name with the "excluded."
     821
     822        The dictionary is modified in any case to reflect the values in
     823        the database after the operation has completed.
     824
     825        Note: The method uses the PostgreSQL "upsert" feature which is
     826        only available since PostgreSQL 9.5.
     827
     828        """
     829        if 'oid' in kw:
     830            del kw['oid']
     831        if d is None:
     832            d = {}
     833        attnames = self.get_attnames(cl)
     834        params = []
     835        param = partial(self._prepare_param,params=params)
     836        col = self.escape_identifier
     837        names, values, updates = [], [], []
     838        for n in attnames:
     839            if n in d:
     840                names.append(col(n))
     841                values.append(param(d[n], attnames[n]))
     842        names, values = ', '.join(names), ', '.join(values)
     843        try:
     844            keyname = self.pkey(cl)
     845        except KeyError:
     846            raise _prg_error('Class %s has no primary key' % cl)
     847        keyname = [keyname] if isinstance(
     848            keyname, basestring) else sorted(keyname)
     849        try:
     850            target = ', '.join(col(k) for k in keyname)
     851        except KeyError:
     852            raise _prg_error('upsert needs primary key or oid')
     853        update = []
     854        keyname = set(keyname)
     855        keyname.add('oid')
     856        for n in attnames:
     857            if n not in keyname:
     858                value = kw.get(n, True)
     859                if value:
     860                    if not isinstance(value, basestring):
     861                        value = 'excluded.%s' % col(n)
     862                    update.append('%s = %s' % (col(n), value))
     863        if not values and not update:
     864            return d
     865        do = 'update set %s' % ', '.join(update) if update else 'nothing'
     866        ret = 'oid, *' if 'oid' in attnames else '*'
     867        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
     868            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
     869                self._escape_qualified_name(cl), names, values,
     870                target, do, ret)
     871        self._do_debug(q, params)
     872        try:
     873            q = self.db.query(q, params)
     874        except ProgrammingError:
     875            if self.server_version < 90500:
     876                raise _prg_error('upsert not supported by PostgreSQL version')
     877            raise  # re-raise original error
     878        res = q.dictresult()
     879        if res:  # may be empty with "do nothing"
     880            for n, value in res[0].items():
     881                if n == 'oid':
     882                    n = _oid_key(cl)
     883                elif attnames.get(n) == 'bytea':
     884                    value = self.unescape_bytea(value)
     885                d[n] = value
     886        elif update:
     887            raise _int_error('upsert did not return new values')
     888        else:
     889            self.get(cl, d)
    777890        return d
    778891
     
    815928        # One day we will be testing that the record to be deleted
    816929        # isn't referenced somewhere (or else PostgreSQL will).
    817         # Note that we only accept oid key from named args for safety
     930        # Note that we only accept oid key from named args for safety.
    818931        qoid = _oid_key(cl)
    819932        if 'oid' in kw:
     
    832945            except KeyError:
    833946                raise _prg_error('Class %s has no primary key' % cl)
    834             if isinstance(keyname, basestring):
    835                 keyname = (keyname,)
     947            keyname = [keyname] if isinstance(
     948                keyname, basestring) else sorted(keyname)
    836949            attnames = self.get_attnames(cl)
    837950            col = self.escape_identifier
    838951            try:
    839                 where = ' AND '.join(['%s = %s'
    840                     % (col(k), param(d[k], attnames[k])) for k in keyname])
     952                where = ' AND '.join('%s = %s'
     953                    % (col(k), param(d[k], attnames[k])) for k in keyname)
    841954            except KeyError:
    842                 raise _prg_error('Delete needs primary key or oid.')
     955                raise _prg_error('delete needs primary key or oid')
    843956        q = 'DELETE FROM %s WHERE %s' % (
    844957            self._escape_qualified_name(cl), where)
    845958        self._do_debug(q, params)
    846         return int(self.db.query(q, params))
     959        res = self.db.query(q, params)
     960        return int(res)
    847961
    848962    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
Note: See TracChangeset for help on using the changeset viewer.