# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this program; if not, write to the
# Free Software Foundation, Inc., 59 Temple Place - Suite 330,
# Boston, MA 02111-1307, USA.
"""\
UnitTests for the Transaction class

See also: https://harald.hoyer.xyz/2015/10/13/a-python-transaction-class/

Run as a script to execute the suite against the final implementation.
Optionally pass one or more variant names (see ``_VARIANTS``) on the
command line to run the same suite against the historical variants.

Copyright (C) 2008 Harald Hoyer
Copyright (C) 2008 Red Hat, Inc.
"""

import unittest
import sys
import copy

from transaction import Transaction


class TransactionOld1(object):
    """\
    Old Transaction class from
    http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/284677

    Snapshots are shallow copies of ``__dict__`` pushed onto ``self.log``.
    """

    def __init__(self):
        self.log = []

    def commit(self, **kwargs):
        """Push a shallow copy of the instance dict (``kwargs`` ignored)."""
        self.log.append(self.__dict__.copy())

    def rollback(self, **kwargs):
        """Merge the newest snapshot back; do nothing if there is none."""
        try:
            self.__dict__.update(self.log.pop(-1))
        except IndexError:
            pass

    def __repr__(self):
        return "'self.__dict__ = %s'" % self.__dict__


class TransactionOld2(TransactionOld1):
    """Like ``TransactionOld1`` but snapshots with ``copy.deepcopy``."""

    def commit(self, **kwargs):
        self.log.append(copy.deepcopy(self.__dict__))


class TransactionOld3(TransactionOld2):
    """Like ``TransactionOld2`` but rollback replaces the whole state."""

    def rollback(self, **kwargs):
        try:
            state = self.log.pop(-1)
        except IndexError:
            return
        # Clear first so attributes added after the commit disappear.
        self.__dict__.clear()
        self.__dict__.update(state)


class TransactionNew1(object):
    """Transaction that propagates commit/rollback to child transactions.

    Snapshots live in a list stored under the ``"log"`` key of
    ``__dict__``; the list is created on the first commit.  Children are
    the attribute values that are instances of ``self.__class__``.
    """

    def _docommit(self):
        """Append a deep copy of the instance dict to the snapshot list."""
        if "log" not in self.__dict__:
            self.__dict__["log"] = list()
        self.__dict__["log"].append(copy.deepcopy(self.__dict__))

    def _dorollback(self):
        """Replace the instance dict with the newest snapshot, if any."""
        if "log" not in self.__dict__:
            return
        try:
            state = self.__dict__["log"].pop(-1)
        except IndexError:
            return
        self.__dict__.clear()
        self.__dict__.update(state)

    def commit(self, **kwargs):
        """Commit this object, then (unless ``deep=False``) its children."""
        # commit ourselves, then our childs
        self._docommit()
        if kwargs.get("deep", True):
            # Iterate over a snapshot of the values so the loop is
            # independent of any change to the dict.
            for child in list(self.__dict__.values()):
                if isinstance(child, self.__class__):
                    child.commit()

    def rollback(self, **kwargs):
        """Roll back the children (unless ``deep=False``), then this object."""
        # rollback our childs, then ourselves
        if kwargs.get("deep", True):
            for child in list(self.__dict__.values()):
                if isinstance(child, self.__class__):
                    child.rollback()
        self._dorollback()

    def __repr__(self):
        return "'self.__dict__ = %s'" % self.__dict__


class TransactionNew2(TransactionNew1):
    """``TransactionNew1`` with a chained snapshot instead of a list.

    ``__dict__["log"]`` holds the newest snapshot (a dict); the snapshot
    before it is stored under that snapshot's own ``"log"`` key.
    """

    def _docommit(self):
        """Snapshot the state without deep-copying older snapshots."""
        # Detach the existing chain so deepcopy does not duplicate it.
        previous = self.__dict__.pop("log", None)
        snapshot = copy.deepcopy(self.__dict__)
        # Compare against None: an empty previous snapshot is still a
        # valid link in the chain and must not be dropped.
        if previous is not None:
            snapshot["log"] = previous
        self.__dict__["log"] = snapshot

    def _dorollback(self):
        """Restore the newest snapshot; its ``"log"`` becomes the new chain."""
        if "log" not in self.__dict__:
            return
        snapshot = self.__dict__["log"]
        self.__dict__.clear()
        self.__dict__.update(snapshot)


class TransactionNew3(TransactionNew2):
    """``TransactionNew2`` with protection against reference cycles.

    The set of already visited object ids is handed down to the children
    through the private ``_commit_seen`` / ``_rollback_seen`` keywords.
    """

    def _checksetseen(self, seen):
        """Return True if ``self`` was already visited, else record it."""
        if id(self) in seen:
            sys.stderr.write("Recursion detected... ")
            return True
        seen.add(id(self))
        return False

    def commit(self, **kwargs):
        # pylint: disable-msg=W0613
        seen = kwargs.get("_commit_seen", set())
        if self._checksetseen(seen):
            return
        # commit ourselves, then our childs
        self._docommit()
        if kwargs.get("deep", True):
            for child in list(self.__dict__.values()):
                if isinstance(child, self.__class__):
                    child.commit(_commit_seen=seen)

    def rollback(self, **kwargs):
        seen = kwargs.get("_rollback_seen", set())
        if self._checksetseen(seen):
            return
        # rollback our childs, then ourselves
        if kwargs.get("deep", True):
            for child in list(self.__dict__.values()):
                if isinstance(child, self.__class__):
                    child.rollback(_rollback_seen=seen)
        self._dorollback()


class TransactionImproved(Transaction):
    """The imported ``Transaction`` class plus a debugging ``__repr__``."""

    def __repr__(self):
        return "'self.__dict__ = %s'" % self.__dict__


# Class instantiated by the tests.  The script entry point rebinds it for
# each variant; the default keeps the tests usable under test discovery.
TestClass = TransactionImproved


class TestTransaction(unittest.TestCase):
    """Behavioural tests run against the module-level ``TestClass``."""

    # Set to True for classes that can handle reference cycles (test10).
    doRecursion = False

    def test01(self):
        "simple rollback"
        a = TestClass()
        a.test = "correct"
        a.commit()
        a.test = "roll me back"
        a.rollback()
        self.assertEqual(a.test, "correct")

    def test02(self):
        """double rollback

        demonstrates how you can commit / rollback several times
        """
        a = TestClass()
        a.test = "correct"
        a.commit()
        a.test = "roll me back second"
        a.commit()
        a.test = "roll me back first"
        a.rollback()
        a.rollback()
        self.assertEqual(a.test, "correct")

    def test03(self):
        """\
        test list rollback

        showing the side effects of a non deep rollback
        """
        a = TestClass()
        a.ls = [0, 1, 2]
        a.commit(deep=False)
        a.ls.append(3)
        a.rollback(deep=False)
        self.assertEqual(a.ls, [0, 1, 2])

    def test031(self):
        """\
        non deep rollback

        showing the side effects of a non deep rollback
        """
        a = TestClass()
        a.ckls = TestClass()
        b = a.ckls
        a.ckls.newvar = "correct"
        a.commit(deep=False)
        a.ckls.newvar = "roll me back"
        a.rollback(deep=False)
        self.assertEqual(b.newvar, "roll me back")
        self.assertEqual(a.ckls.newvar, "roll me back")

    def test04(self):
        """commit not deep, rollback deep

        showing the side effects of a non deep commit
        """
        a = TestClass()
        a.ckls = TestClass()
        a.ckls.newvar = TestClass()
        a.ckls.newvar.text = "correct"
        b = a.ckls
        a.commit(deep=False)
        a.ckls.newvar.text = "roll me back"
        a.rollback(deep=True)
        self.assertEqual(b.newvar.text, "roll me back")
        self.assertEqual(a.ckls.newvar.text, "roll me back")

    def test041(self):
        """check for leftover attributes"""
        a = TestClass()
        a.newvar = "correct"
        a.commit()
        a.shouldnotbethere = True
        a.rollback()
        # assertFalse is the supported spelling of the removed failIf alias.
        self.assertFalse(hasattr(a, "shouldnotbethere"), a)

    def test05(self):
        """commit and rollback deep

        no more side effects
        """
        a = TestClass()
        a.ckls = TestClass()
        a.ckls.newvar = "correct"
        b = a.ckls
        a.commit()
        a.ckls.newvar = "roll me back"
        a.rollback(deep=True)
        self.assertEqual(a.ckls.newvar, "correct")
        self.assertEqual(b.newvar, "correct")
        self.assertEqual(id(a.ckls), id(b))

    def test06(self):
        """commit only a sub object

        though we committed only an attribute, the deep rollback will
        roll it back.
        """
        a = TestClass()
        a.ckls = TestClass()
        a.newvar = "correct"
        a.ckls.newvar = "correct"
        a.ckls.commit()
        a.newvar = "will not be rolled back"
        a.ckls.newvar = "roll me back"
        a.rollback()
        self.assertEqual(a.newvar, "will not be rolled back")
        self.assertEqual(a.ckls.newvar, "correct")

    def test07(self):
        """commit only a sub object, rollback with deep=false

        we committed only an attribute and the non deep rollback will
        not roll it back.
        """
        a = TestClass()
        a.ckls = TestClass()
        a.newvar = "correct"
        a.ckls.newvar = "correct"
        a.ckls.commit()
        a.newvar = "will not be rolled back"
        a.ckls.newvar = "will not be rolled back"
        a.rollback(deep=False)
        self.assertEqual(a.newvar, "will not be rolled back")
        self.assertEqual(a.ckls.newvar, "will not be rolled back")

    def test10(self):
        """check for the commit/rollback recursion"""
        if not TestTransaction.doRecursion:
            sys.stderr.write("skipped .. ")
            return
        a = TestClass()
        b = a
        for i in range(10):
            b.newvar = TestClass()
            b.test = "test" + str(i)
            b = b.newvar
        # Close the chain into a cycle.
        b.newvar = a
        # would raise a recursion maximum exception
        a.commit(deep=True)

    def test11(self):
        """check for swapping Transaction objects"""
        a = TestClass()
        a.t1 = TestClass()
        a.t2 = TestClass()
        a.t1.text = "test1"
        a.t2.text = "test2"
        a.commit()
        a.t1, a.t2 = a.t2, a.t1
        a.rollback()
        self.assertEqual(a.t1.text, "test1")
        self.assertEqual(a.t2.text, "test2")

    def count_dict(self, seen, what):
        """Count the places reachable from ``what`` through ``__dict__``.

        ``seen`` is a set of object ids that is filled in as a side
        effect.  Every visit counts as one place, but an object whose id
        is already in ``seen`` is not descended into again.
        """
        if id(what) in seen:
            return 1
        seen.add(id(what))
        if not hasattr(what, "__dict__"):
            return 1
        total = 1
        for val in what.__dict__.values():
            total += self.count_dict(seen, val)
        return total

    def test91(self):
        """check for the stack length (deep=False)"""
        a = TestClass()
        b = a
        for i in range(10):
            b.newvar = TestClass()
            b.test = "test" + str(i)
            b = b.newvar
        a.commit(deep=False)
        a.commit(deep=False)
        a.commit(deep=False)
        seen = set()
        count = self.count_dict(seen, a)
        sys.stderr.write("%d objects in %d places .. " % (len(seen), count))

    def test92(self):
        """check for the stack length (deep=True)"""
        a = TestClass()
        b = a
        for i in range(10):
            b.newvar = TestClass()
            b.test = "test" + str(i)
            b.test2 = "test2" + str(i)
            b = b.newvar
        a.commit(deep=True)
        a.commit()
        a.commit()
        a.commit()
        seen = set()
        count = self.count_dict(seen, a)
        sys.stderr.write("%d objects in %d places .. " % (len(seen), count))


def suite():
    """Return a suite with every ``test*`` method of ``TestTransaction``."""
    return unittest.TestLoader().loadTestsFromTestCase(TestTransaction)


_BANNER_RULE = (
    "**********************************************************************"
)

# (name, banner title, class under test, run the recursion test)
_VARIANTS = (
    ("old1", "Old Transaction Class (original)", TransactionOld1, False),
    ("old2", "Old Transaction Class (copy.deepcopy)", TransactionOld2, False),
    ("old3", "Old Transaction Class (copy.deepcopy + __dict__.clear)",
     TransactionOld3, False),
    ("new1", "New Transaction Class (deep commit)", TransactionNew1, False),
    ("new2", "New Transaction Class (deep commit + stack improvement)",
     TransactionNew2, False),
    ("new3",
     "New Transaction Class (deep commit + stack improvement + "
     "recursion check)",
     TransactionNew3, True),
    ("final", "New Transaction Class (final)", TransactionImproved, True),
)


def _run_variant(title, test_class, do_recursion=False):
    """Run the suite against ``test_class`` and return the test result."""
    global TestClass
    sys.stderr.write(
        "%s\n%s\n%s\n\n" % (_BANNER_RULE, title, _BANNER_RULE)
    )
    TestClass = test_class
    TestTransaction.doRecursion = do_recursion
    testrunner = unittest.TextTestRunner(verbosity=2)
    return testrunner.run(suite())


def main(argv=None):
    """Run the selected variants and return a process exit status.

    ``argv`` is a list of variant names (default: ``sys.argv[1:]``).
    With no names only the final class is tested.  Returns 0 if every
    run succeeded, 1 if any test failed, 2 for an unknown variant name.
    """
    names = list(sys.argv[1:] if argv is None else argv) or ["final"]
    variants = dict((entry[0], entry[1:]) for entry in _VARIANTS)
    unknown = [name for name in names if name not in variants]
    if unknown:
        sys.stderr.write(
            "unknown variant(s): %s (choose from: %s)\n"
            % (", ".join(unknown), ", ".join(e[0] for e in _VARIANTS))
        )
        return 2
    success = True
    for name in names:
        title, test_class, do_recursion = variants[name]
        result = _run_variant(title, test_class, do_recursion)
        success = result.wasSuccessful() and success
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())