Changeset 391


Ignore:
Timestamp:
Dec 5, 2008, 10:00:19 AM (11 years ago)
Author:
cito
Message:

The insert() and update() methods now check if the table is selectable and use the "returning" clause if possible. The delete() method now also works with primary keys and returns whether the row existed.

Location:
trunk
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • trunk/docs/changelog.txt

    r389 r391  
    3434  submitting an empty command.
    3535- Removed compatibility code for old OID munging style.
    36 - Added "return_changes" flag to insert method.
     36- The insert() and update() methods now use the "returning" clause
     37  if possible to get all changed values, and they also check in advance
     38  whether a subsequent select is possible, so that ongoing transactions
     39  won't break if there is no select privilege.
    3740- Added "protocol_version" and "server_version" attributes.
    3841- Revived the "user" attribute.
     
    4346- get() raises a nicer ProgrammingError instead of a KeyError
    4447  if no primary key was found.
     48- delete() now also works based on the primary key if no oid available
     49  and returns whether the row existed or not.
     50
    4551
    4652Version 3.8.1 (2006-06-05)
  • trunk/docs/future.txt

    r389 r391  
    77-----
    88
    9 - Classic pg relies on OIDs, but these are not generated by default any more
    10   (at least docs should recommend setting default_with_oids=true).
    11 - The insert() method in pg should use a `returning *` clause instead of
    12   the subsequent get() when the PostgreSQL version is >= 8.2.
    13 - Before getting the values back after insert() and update(), check
    14   with has_table_privilege(table, 'select') whether this is possible.
    15   Otherwise we might break an ongoing transaction.
    169- Documentation for the pgdb module (everything specific to PyGreSQL).
    1710- The large object and direct access functions need much more attention.
  • trunk/docs/pg.txt

    r389 r391  
    845845  Given the name of a table, digs out the set of attribute names.
    846846
     847has_table_privilege - check whether current user has specified table privilege
     848------------------------------------------------------------------------------
     849Syntax::
     850
     851    has_table_privilege(table, privilege)
     852
     853Parameters:
     854  :table:     name of table
     855  :privilege: privilege to be checked - default is 'select'
     856
     857Description:
     858  Returns True if the current user has the specified privilege for the table.
     859
    847860get - get a row from a database table or view
    848861---------------------------------------------
     
    891904  added to or replace the entry in the dictionary.
    892905
    893   The dictionary is then reloaded with the values actually inserted
    894   in order to pick up values modified by rules, triggers, etc.  If
    895   the optional flag return_changes is set to False this reload will
    896   be skipped.
     906  The dictionary is then, if possible, reloaded with the values actually
     907  inserted in order to pick up values modified by rules, triggers, etc.
    897908
    898909  Due to the way that this function works you will find inserts taking
     
    919930
    920931Description:
    921   Similar to insert but updates an existing row. The update is based
    922   on the OID value as munged by get. The array returned is the
    923   one sent modified to reflect any changes caused by the update due
    924   to triggers, rules, defaults, etc.
     932  Similar to insert but updates an existing row.  The update is based on the
     933  OID value as munged by get or passed as keyword, or on the primary key of
     934  the table.  The dictionary is modified, if possible, to reflect any changes
     935  caused by the update due to triggers, rules, default values, etc.
    925936
    926937  Like insert, the dictionary is optional and updates will be performed
     
    965976
    966977Description:
    967   This method deletes the row from a table. It deletes based on the OID
    968   as munged as described above.  Alternatively, the keyword "oid" can
    969   be used to specify the OID.
     978  This method deletes the row from a table.  It deletes based on the OID value
     979  as munged by get or passed as keyword, or on the primary key of the table.
     980  The return value is the number of deleted rows (i.e. 0 if the row did not
     981  exist and 1 if the row was deleted).
    970982
    971983escape_string - escape a string for use within SQL
  • trunk/module/pg.py

    r390 r391  
    66# Improved by Christoph Zwerschke
    77#
    8 # $Id: pg.py,v 1.75 2008-12-05 02:08:15 cito Exp $
     8# $Id: pg.py,v 1.76 2008-12-05 15:00:19 cito Exp $
    99#
    1010
     
    128128        self._attnames = {}
    129129        self._pkeys = {}
     130        self._privileges = {}
    130131        self._args = args, kw
    131132        self.debug = None # For debugging scripts, this can be set
     
    227228            cl = s[0]
    228229            # determine search path
    229             query = 'SELECT current_schemas(TRUE)'
    230             schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
     230            q = 'SELECT current_schemas(TRUE)'
     231            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
    231232            if schemas: # non-empty path
    232233                # search schema for this object in the current search path
    233                 query = ' UNION '.join(
     234                q = ' UNION '.join(
    234235                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
    235236                        % s for s in enumerate(schemas)])
    236                 query = ("SELECT nspname FROM pg_class"
     237                q = ("SELECT nspname FROM pg_class"
    237238                    " JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid"
    238239                    " JOIN (%s) AS p USING (nspname)"
    239240                    " WHERE pg_class.relname = '%s'"
    240                     " ORDER BY n LIMIT 1" % (query, cl))
    241                 schema = self.db.query(query).getresult()
     241                    " ORDER BY n LIMIT 1" % (q, cl))
     242                schema = self.db.query(q).getresult()
    242243                if schema: # schema found
    243244                    schema = schema[0][0]
     
    295296
    296297        This method simply sends a SQL query to the database. If the query is
    297         an insert statement, the return value is the OID of the newly
    298         inserted row.  If it is otherwise a query that does not return a result
    299         (ie. is not a some kind of SELECT statement), it returns None.
    300         Otherwise, it returns a pgqueryobject that can be accessed via the
    301         getresult or dictresult method or simply printed.
     298        an insert statement that inserted exactly one row into a table that
     299        has OIDs, the return value is the OID of the newly inserted row.
     300        If the query is an update or delete statement, or an insert statement
     301        that did not insert exactly one row in a table with OIDs, then the
     302        numer of rows affected is returned as a string. If it is a statement
     303        that returns rows as a result (usually a select statement, but maybe
     304        also an "insert/update ... returning" statement), this method returns
     305        a pgqueryobject that can be accessed via getresult() or dictresult()
     306        or simply printed. Otherwise, it returns `None`.
    302307
    303308        """
     
    447452        return self._attnames[qcl]
    448453
     454    def has_table_privilege(self, cl, privilege='select'):
     455        """Check whether current user has specified table privilege."""
     456        qcl = self._add_schema(cl)
     457        privilege = privilege.lower()
     458        try:
     459            return self._privileges[(qcl, privilege)]
     460        except KeyError:
     461            q = "SELECT has_table_privilege('%s', '%s')" % (qcl, privilege)
     462            ret = self.db.query(q).getresult()[0][0] == 't'
     463            self._privileges[(qcl, privilege)] = ret
     464            return ret
     465
    449466    def get(self, cl, arg, keyname=None):
    450467        """Get a tuple from a database table or view.
     
    503520        return arg
    504521
    505     def insert(self, cl, d=None, return_changes=True, **kw):
     522    def insert(self, cl, d=None, **kw):
    506523        """Insert a tuple into a database table.
    507524
     
    510527        Either way the dictionary is updated from the keywords.
    511528
    512         The dictionary is then reloaded with the values actually inserted
    513         in order to pick up values modified by rules, triggers, etc.  If
    514         the optional flag return_changes is set to False this reload will
    515         be skipped.
     529        The dictionary is then, if possible, reloaded with the values actually
     530        inserted in order to pick up values modified by rules, triggers, etc.
    516531
    517532        Note: The method currently doesn't support insert into views
     
    519534
    520535        """
     536        qcl = self._add_schema(cl)
     537        qoid = _oid_key(qcl)
    521538        if d is None:
    522539            d = {}
    523540        d.update(kw)
    524         qcl = self._add_schema(cl)
    525         qoid = _oid_key(qcl)
    526541        attnames = self.get_attnames(qcl)
    527542        names, values = [], []
     
    531546                values.append(self._quote(d[n], attnames[n]))
    532547        names, values = ', '.join(names), ', '.join(values)
    533         q = 'INSERT INTO %s (%s) VALUES (%s)' % (qcl, names, values)
     548        selectable = self.has_table_privilege(qcl)
     549        if selectable and self.server_version >= 80200:
     550            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
     551        else:
     552            ret = ''
     553        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
    534554        self._do_debug(q)
    535         d[qoid] = self.db.query(q)
    536         # Reload the dictionary to catch things modified by engine.
    537         # Note that get() changes 'oid' below to oid(schema.table).
    538         if return_changes:
    539             self.get(qcl, d, 'oid')
     555        res = self.db.query(q)
     556        if ret:
     557            res = res.dictresult()
     558            for att, value in res[0].iteritems():
     559                d[att == 'oid' and qoid or att] = value
     560        elif isinstance(res, int):
     561            d[qoid] = res
     562            if selectable:
     563                self.get(qcl, d, 'oid')
     564        elif selectable:
     565            if qoid in d:
     566                self.get(qcl, d, 'oid')
     567            else:
     568                try:
     569                    self.get(qcl, d)
     570                except ProgrammingError:
     571                    pass # table has no primary key
    540572        return d
    541573
     
    544576
    545577        Similar to insert but updates an existing row.  The update is based
    546         on the OID value as munged by get.  The array returned is the
    547         one sent modified to reflect any changes caused by the update due
    548         to triggers, rules, defaults, etc.
     578        on the OID value as munged by get or passed as keyword, or on the
     579        primary key of the table.  The dictionary is modified, if possible,
     580        to reflect any changes caused by the update due to triggers, rules,
     581        default values, etc.
    549582
    550583        """
     
    552585        # otherwise use the primary key.  Fail if neither.
    553586        # Note that we only accept oid key from named args for safety
     587        qcl = self._add_schema(cl)
     588        qoid = _oid_key(qcl)
    554589        if 'oid' in kw:
    555590            kw[qoid] = kw['oid']
     
    558593            d = {}
    559594        d.update(kw)
    560         qcl = self._add_schema(cl)
    561         qoid = _oid_key(qcl)
    562595        attnames = self.get_attnames(qcl)
    563596        if qoid in d:
     
    566599        else:
    567600            try:
    568                keyname = self.pkey(qcl)
     601                keyname = self.pkey(qcl)
    569602            except KeyError:
    570                raise ProgrammingError('Class %s has no primary key' % qcl)
     603                raise ProgrammingError('Class %s has no primary key' % qcl)
    571604            if isinstance(keyname, basestring):
    572605                keyname = (keyname,)
     
    583616            return d
    584617        values = ', '.join(values)
    585         q = 'UPDATE %s SET %s WHERE %s' % (qcl, values, where)
     618        selectable = self.has_table_privilege(qcl)
     619        if selectable and self.server_version >= 880200:
     620            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
     621        else:
     622            ret = ''
     623        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
    586624        self._do_debug(q)
    587         self.db.query(q)
    588         # Reload the dictionary to catch things modified by engine:
    589         if qoid in d:
    590             return self.get(qcl, d, 'oid')
    591         else:
    592             return self.get(qcl, d)
     625        res = self.db.query(q)
     626        if ret:
     627            res = self.db.query(q).dictresult()
     628            for att, value in res[0].iteritems():
     629                d[att == 'oid' and qoid or att] = value
     630        else:
     631            self.db.query(q)
     632            if selectable:
     633                if qoid in d:
     634                    self.get(qcl, d, 'oid')
     635                else:
     636                    self.get(qcl, d)
     637        return d
    593638
    594639    def clear(self, cl, a=None):
     
    603648        """
    604649        # At some point we will need a way to get defaults from a table.
     650        qcl = self._add_schema(cl)
    605651        if a is None:
    606652            a = {} # empty if argument is not present
    607         qcl = self._add_schema(cl)
    608653        attnames = self.get_attnames(qcl)
    609654        for n, t in attnames.iteritems():
     
    621666        """Delete an existing row in a database table.
    622667
    623         This method deletes the row from a table.
    624         It deletes based on the OID munged as described above.
     668        This method deletes the row from a table.  It deletes based on the
     669        OID value as munged by get or passed as keyword, or on the primary
     670        key of the table.  The return value is the number of deleted rows
     671        (i.e. 0 if the row did not exist and 1 if the row was deleted).
    625672
    626673        """
     
    629676        # isn't referenced somewhere (or else PostgreSQL will).
    630677        # Note that we only accept oid key from named args for safety
     678        qcl = self._add_schema(cl)
     679        qoid = _oid_key(qcl)
    631680        if 'oid' in kw:
    632681            kw[qoid] = kw['oid']
     
    635684            d = {}
    636685        d.update(kw)
    637         qcl = self._add_schema(cl)
    638         qoid = _oid_key(qcl)
    639         q = 'DELETE FROM %s WHERE oid=%s' % (qcl, d[qoid])
     686        if qoid in d:
     687            where = 'oid = %s' % d[qoid]
     688        else:
     689            try:
     690                keyname = self.pkey(qcl)
     691            except KeyError:
     692                raise ProgrammingError('Class %s has no primary key' % qcl)
     693            if isinstance(keyname, basestring):
     694                keyname = (keyname,)
     695            attnames = self.get_attnames(qcl)
     696            try:
     697                where = ' AND '.join(['%s = %s'
     698                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
     699            except KeyError:
     700                raise ProgrammingError('Delete needs primary key or oid.')
     701        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
    640702        self._do_debug(q)
    641         self.db.query(q)
     703        return int(self.db.query(q))
    642704
    643705
  • trunk/module/test_pg.py

    r389 r391  
    55# Written by Christoph Zwerschke
    66#
    7 # $Id: test_pg.py,v 1.26 2008-12-05 02:05:28 cito Exp $
     7# $Id: test_pg.py,v 1.27 2008-12-05 15:00:19 cito Exp $
    88#
    99
     
    1313
    1414There are a few drawbacks:
    15 * A local PostgreSQL database must be up and running, and
    16 the user who is running the tests must be a trusted superuser.
    17 * The performance of the API is not tested.
    18 * Connecting to a remote host is not tested.
    19 * Passing user, password and options is not tested.
    20 * Status and error messages from the connection are not tested.
    21 * It would be more reasonable to create a test for the underlying
    22 shared library functions in the _pg module and assume they are ok.
    23 The pg and pgdb modules should be tested against _pg mock functions.
     15  * A local PostgreSQL database must be up and running, and
     16    the user who is running the tests must be a trusted superuser.
     17  * The performance of the API is not tested.
     18  * Connecting to a remote host is not tested.
     19  * Passing user, password and options is not tested.
     20  * Table privilege problems (e.g. insert but no select) are not tested.
     21  * Status and error messages from the connection are not tested.
     22  * It would be more reasonable to create a test for the underlying
     23    shared library functions in the _pg module and assume they are ok.
     24    The pg and pgdb modules should be tested against _pg mock functions.
    2425
    2526"""
     
    356357
    357358    def testAllConnectAttributes(self):
    358         attributes = ['db', 'error', 'host', 'options', 'port',
    359             'protocol_version', 'server_version', 'status', 'tty', 'user']
     359        attributes = '''db error host options port
     360            protocol_version server_version status tty user'''.split()
    360361        connection_attributes = [a for a in dir(self.connection)
    361362            if not callable(eval("self.connection." + a))]
     
    363364
    364365    def testAllConnectMethods(self):
    365         methods = ['cancel', 'close', 'endcopy', 'escape_bytea',
    366             'escape_string', 'fileno', 'getline', 'getlo', 'getnotify',
    367             'inserttable', 'locreate', 'loimport', 'parameter', 'putline',
    368             'query', 'reset', 'source', 'transaction']
     366        methods = '''cancel close endcopy escape_bytea escape_string fileno
     367            getline getlo getnotify inserttable locreate loimport parameter
     368            putline query reset source transaction'''.split()
    369369        connection_methods = [a for a in dir(self.connection)
    370370            if callable(eval("self.connection." + a))]
     
    443443
    444444    def setUp(self):
    445         dbname = 'template1'
     445        dbname = 'test'
    446446        self.c = pg.connect(dbname)
    447447
     
    693693
    694694    def testAllDBAttributes(self):
    695         attributes = ['cancel', 'clear', 'close', 'db', 'dbname', 'debug',
    696             'delete', 'endcopy', 'error', 'escape_bytea', 'escape_string',
    697             'fileno', 'get', 'get_attnames', 'get_databases', 'get_relations',
    698             'get_tables', 'getline', 'getlo', 'getnotify', 'host', 'insert',
    699             'inserttable', 'locreate', 'loimport', 'options', 'parameter',
    700             'pkey', 'port', 'protocol_version', 'putline', 'query', 'reopen',
    701             'reset', 'server_version', 'source', 'status', 'transaction',
    702             'tty', 'unescape_bytea', 'update', 'user']
     695        attributes = '''cancel clear close db dbname debug delete endcopy
     696            error escape_bytea escape_string fileno  get get_attnames
     697            get_databases get_relations get_tables getline getlo getnotify
     698            has_table_privilege host insert inserttable locreate loimport
     699            options parameter pkey port protocol_version putline query
     700            reopen reset server_version source status transaction tty
     701            unescape_bytea update user'''.split()
    703702        db_attributes = [a for a in dir(self.db)
    704703            if not a.startswith('_')]
     
    10611060            self.assertEqual(attributes, result)
    10621061
     1062    def testHasTablePrivilege(self):
     1063        can = self.db.has_table_privilege
     1064        self.assertEqual(can('test'), True)
     1065        self.assertEqual(can('test', 'select'), True)
     1066        self.assertEqual(can('test', 'SeLeCt'), True)
     1067        self.assertEqual(can('test', 'SELECT'), True)
     1068        self.assertEqual(can('test', 'insert'), True)
     1069        self.assertEqual(can('test', 'update'), True)
     1070        self.assertEqual(can('test', 'delete'), True)
     1071        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
     1072        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
     1073
    10631074    def testGet(self):
    10641075        for table in ('get_test_table', 'test table for get'):
     
    11791190
    11801191    def testUpdateWithCompositeKey(self):
    1181         table = 'get_update_table_1'
     1192        table = 'update_test_table_1'
    11821193        smart_ddl(self.db, "drop table %s" % table)
    11831194        smart_ddl(self.db, "create table %s ("
     
    11921203            ).getresult()[0][0]
    11931204        self.assertEqual(r, 'd')
    1194         table = 'get_test_update_2'
     1205        table = 'update_test_table_2'
    11951206        smart_ddl(self.db, "drop table %s" % table)
    11961207        smart_ddl(self.db, "create table %s ("
     
    12361247            r = self.db.get(table, 1, 'n')
    12371248            s = self.db.delete(table, r)
     1249            self.assertEqual(s, 1)
    12381250            r = self.db.get(table, 3, 'n')
    12391251            s = self.db.delete(table, r)
     1252            self.assertEqual(s, 1)
     1253            s = self.db.delete(table, r)
     1254            self.assertEqual(s, 0)
    12401255            r = self.db.query('select * from "%s"' % table).dictresult()
    12411256            self.assertEqual(len(r), 1)
     
    12451260            r = self.db.get(table, 2, 'n')
    12461261            s = self.db.delete(table, r)
     1262            self.assertEqual(s, 1)
     1263            s = self.db.delete(table, r)
     1264            self.assertEqual(s, 0)
    12471265            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
     1266
     1267    def testDeleteWithCompositeKey(self):
     1268        table = 'delete_test_table_1'
     1269        smart_ddl(self.db, "drop table %s" % table)
     1270        smart_ddl(self.db, "create table %s ("
     1271            "n integer, t text, primary key (n))" % table)
     1272        for n, t in enumerate('abc'):
     1273            self.db.query("insert into %s values("
     1274                "%d, '%s')" % (table, n+1, t))
     1275        self.assertRaises(pg.ProgrammingError, self.db.delete,
     1276            table, dict(t='b'))
     1277        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
     1278        r = self.db.query('select t from "%s" where n=2' % table
     1279            ).getresult()
     1280        self.assertEqual(r, [])
     1281        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
     1282        r = self.db.query('select t from "%s" where n=3' % table
     1283            ).getresult()[0][0]
     1284        self.assertEqual(r, 'c')
     1285        table = 'delete_test_table_2'
     1286        smart_ddl(self.db, "drop table %s" % table)
     1287        smart_ddl(self.db, "create table %s ("
     1288            "n integer, m integer, t text, primary key (n, m))" % table)
     1289        for n in range(3):
     1290            for m in range(2):
     1291                t = chr(ord('a') + 2*n +m)
     1292                self.db.query("insert into %s values("
     1293                    "%d, %d, '%s')" % (table, n+1, m+1, t))
     1294        self.assertRaises(pg.ProgrammingError, self.db.delete,
     1295            table, dict(n=2, t='b'))
     1296        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
     1297        r = [r[0] for r in self.db.query('select t from "%s" where n=2'
     1298            ' order by m' % table).getresult()]
     1299        self.assertEqual(r, ['c'])
     1300        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
     1301        r = [r[0] for r in self.db.query('select t from "%s" where n=3'
     1302            ' order by m' % table).getresult()]
     1303        self.assertEqual(r, ['e', 'f'])
     1304        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
     1305        r = [r[0] for r in self.db.query('select t from "%s" where n=3'
     1306            ' order by m' % table).getresult()]
     1307        self.assertEqual(r, ['f'])
    12481308
    12491309    def testBytea(self):
Note: See TracChangeset for help on using the changeset viewer.