diff --git a/epu/epumanagement/reactor.py b/epu/epumanagement/reactor.py index 0e242b82..32b191fe 100644 --- a/epu/epumanagement/reactor.py +++ b/epu/epumanagement/reactor.py @@ -6,7 +6,7 @@ from epu.epumanagement.decider import DEFAULT_ENGINE_CLASS from epu.states import InstanceState, InstanceHealthState from epu.domain_log import EpuLoggerThreadSpecific -from epu.exceptions import NotFoundError +from epu.exceptions import NotFoundError, WriteConflictError from epu.util import get_class log = logging.getLogger(__name__) @@ -233,10 +233,12 @@ def new_instance_state(self, content): if instance_id: domain = self.store.get_domain_for_instance_id(instance_id) if domain: - log.debug("Got state %s for instance '%s'", state, instance_id) - instance = domain.get_instance(instance_id) - if domain.new_instance_state(content, previous=instance): + # retry update in case of write conflict + instance, updated = self._maybe_update_domain_instance( + domain, instance_id, content) + if updated: + log.debug("Got state %s for instance '%s'", state, instance_id) # The higher level clients of EPUM only see RUNNING or FAILED (or nothing) if content['state'] < InstanceState.RUNNING: @@ -257,6 +259,18 @@ def new_instance_state(self, content): else: log.error("Could not parse instance ID from state message: '%s'" % content) + def _maybe_update_domain_instance(self, domain, instance_id, msg): + while True: + instance = domain.get_instance(instance_id) + content = copy.deepcopy(msg) + try: + updated = domain.new_instance_state(content, previous=instance) + return instance, updated + except WriteConflictError: + pass + except NotFoundError: + return instance, False + def new_heartbeat(self, caller, content, timestamp=None): """Handle an incoming heartbeat message diff --git a/epu/epumanagement/store.py b/epu/epumanagement/store.py index a6294f81..c734755c 100644 --- a/epu/epumanagement/store.py +++ b/epu/epumanagement/store.py @@ -343,6 +343,7 @@ def new_instance_state(self, content, timestamp=None, previous=None): return True # instance was probably a duplicate return False + return False def mark_instance_terminating(self, instance_id): """Mark an instance for termination diff --git a/epu/epumanagement/test/test_controller_core.py b/epu/epumanagement/test/test_controller_core.py index 67d0f6e6..0b03269d 100644 --- a/epu/epumanagement/test/test_controller_core.py +++ b/epu/epumanagement/test/test_controller_core.py @@ -2,15 +2,19 @@ import itertools import uuid import unittest +import threading + +from mock import Mock from epu.decisionengine.impls.simplest import CONF_PRESERVE_N from epu.epumanagement.conf import * # noqa -from epu.epumanagement.store import LocalDomainStore +from epu.epumanagement.store import LocalDomainStore, ZooKeeperDomainStore from epu.states import InstanceState, InstanceHealthState from epu.epumanagement.decider import ControllerCoreControl from epu.epumanagement.core import EngineState, CoreInstance from epu.epumanagement.test.mocks import MockProvisionerClient -from epu.test import Mock +from epu.test import ZooKeeperTestMixin +from epu.exceptions import WriteConflictError log = logging.getLogger(__name__) @@ -95,6 +99,54 @@ def test_instances(self): self.assertEqual(len(all_instances), 1) self.assertIn(instance_id, all_instances) + def test_instance_update_conflict(self): + launch_id = str(uuid.uuid4()) + instance_id = str(uuid.uuid4()) + self.domain.new_instance_launch("dtid", instance_id, launch_id, + "chicago", "big", timestamp=1) + + sneaky_msg = dict(node_id=instance_id, launch_id=launch_id, + site="chicago", allocation="big", + state=InstanceState.PENDING) + + # patch in a function that sneaks in an instance record update just + # before a requested update. This simulates the case where two EPUM + # workers are competing to update the same instance. + original_update_instance = self.domain.update_instance + + patch_called = threading.Event() + + def patched_update_instance(*args, **kwargs): + patch_called.set() + # unpatch ourself first so we don't recurse forever + self.domain.update_instance = original_update_instance + + self.domain.new_instance_state(sneaky_msg, timestamp=2) + original_update_instance(*args, **kwargs) + self.domain.update_instance = patched_update_instance + + # send our "real" update. should get a conflict + msg = dict(node_id=instance_id, launch_id=launch_id, + site="chicago", allocation="big", + state=InstanceState.STARTED) + + with self.assertRaises(WriteConflictError): + self.domain.new_instance_state(msg, timestamp=2) + + assert patch_called.is_set() + + +class ZooKeeperControllerStateStoreTests(ControllerStateStoreTests, ZooKeeperTestMixin): + + # this runs all of the ControllerStateStoreTests tests plus any + # ZK-specific ones + + def setUp(self): + self.setup_zookeeper("/epum_store_tests_") + self.addCleanup(self.teardown_zookeeper) + self.domain = ZooKeeperDomainStore("david", "domain1", self.kazoo, + self.kazoo.retry, self.zk_base_path) + class ControllerCoreStateTests(BaseControllerStateTests): """ControllerCoreState tests that only use in memory store @@ -109,7 +161,7 @@ def test_instance_extravars(self): (when they don't arrive in state updates) """ - extravars = {'iwant': 'asandwich', 4: 'real'} + extravars = {'iwant': 'asandwich', '4': 'real'} launch_id, instance_id = self.new_instance(1, extravars=extravars) self.new_instance_state(launch_id, instance_id, @@ -205,6 +257,18 @@ def test_out_of_order_instance(self): InstanceState.STARTED) +class ZooKeeperControllerCoreStateStoreTests(ControllerCoreStateTests, ZooKeeperTestMixin): + + # this runs all of the ControllerCoreStateTests tests plus any + # ZK-specific ones + + def setUp(self): + self.setup_zookeeper("/epum_store_tests_") + self.addCleanup(self.teardown_zookeeper) + self.domain = ZooKeeperDomainStore("david", "domain1", self.kazoo, + self.kazoo.retry, self.zk_base_path) + + class EngineStateTests(unittest.TestCase): def test_instances(self): diff --git a/epu/epumanagement/test/test_epumanagement.py b/epu/epumanagement/test/test_epumanagement.py index f9a42fd7..cbfdd5ae 100644 --- a/epu/epumanagement/test/test_epumanagement.py +++ b/epu/epumanagement/test/test_epumanagement.py @@ -2,6 +2,7 @@ import unittest import logging import time +import threading from epu.decisionengine.impls.simplest import CONF_PRESERVE_N from epu.epumanagement import EPUManagement @@ -729,3 +730,92 @@ def test_reaper(self): self.assertIn("n5", instances) self.assertIn("n6", instances) self.assertIn("n7", instances) + + def test_instance_update_conflict_1(self): + + self.epum.initialize() + domain_config = self._config_simplest_domainconf(1) + definition = {} + self.epum.msg_add_domain_definition("definition1", definition) + self.epum.msg_add_domain("owner1", "testing123", "definition1", domain_config) + self.epum._run_decisions() + self.assertEqual(self.provisioner_client.provision_count, 1) + + domain = self.epum_store.get_domain("owner1", "testing123") + + instance_id = self.provisioner_client.launched_instance_ids[0] + launch_id = self.provisioner_client.launches[0]['launch_id'] + + sneaky_msg = dict(node_id=instance_id, state=InstanceState.PENDING) + + # patch in a function that sneaks in an instance record update just + # before a requested update. This simulates the case where two EPUM + # workers are competing to update the same instance. + original_new_instance_state = domain.new_instance_state + + patch_called = threading.Event() + + def patched_new_instance_state(content, timestamp=None, previous=None): + patch_called.set() + + # unpatch ourself first so we don't recurse forever + domain.new_instance_state = original_new_instance_state + + domain.new_instance_state(sneaky_msg, previous=previous) + return domain.new_instance_state(content, timestamp=timestamp, previous=previous) + domain.new_instance_state = patched_new_instance_state + + # send our "real" update. should get a conflict + msg = dict(node_id=instance_id, state=InstanceState.STARTED) + + self.epum.msg_instance_info("owner1", msg) + + assert patch_called.is_set() + + instance = domain.get_instance(instance_id) + self.assertEqual(instance.state, InstanceState.STARTED) + + def test_instance_update_conflict_2(self): + + self.epum.initialize() + domain_config = self._config_simplest_domainconf(1) + definition = {} + self.epum.msg_add_domain_definition("definition1", definition) + self.epum.msg_add_domain("owner1", "testing123", "definition1", domain_config) + self.epum._run_decisions() + self.assertEqual(self.provisioner_client.provision_count, 1) + + domain = self.epum_store.get_domain("owner1", "testing123") + + instance_id = self.provisioner_client.launched_instance_ids[0] + launch_id = self.provisioner_client.launches[0]['launch_id'] + + sneaky_msg = dict(node_id=instance_id, state=InstanceState.STARTED) + + # patch in a function that sneaks in an instance record update just + # before a requested update. This simulates the case where two EPUM + # workers are competing to update the same instance. + original_new_instance_state = domain.new_instance_state + + patch_called = threading.Event() + def patched_new_instance_state(content, timestamp=None, previous=None): + patch_called.set() + + # unpatch ourself first so we don't recurse forever + domain.new_instance_state = original_new_instance_state + + domain.new_instance_state(sneaky_msg, previous=previous) + return domain.new_instance_state(content, timestamp=timestamp, previous=previous) + domain.new_instance_state = patched_new_instance_state + + # send our "real" update. should get a conflict + msg = dict(node_id=instance_id, state=InstanceState.PENDING) + + self.epum.msg_instance_info(None, msg) + + assert patch_called.is_set() + + # in this case the sneaky message (STARTED) should win because it is + # the later state + instance = domain.get_instance(instance_id) + self.assertEqual(instance.state, InstanceState.STARTED) diff --git a/epu/provisioner/test/util.py b/epu/provisioner/test/util.py index b31e2538..f0441aba 100644 --- a/epu/provisioner/test/util.py +++ b/epu/provisioner/test/util.py @@ -13,9 +13,9 @@ from libcloud.compute.base import NodeDriver, Node, NodeSize from libcloud.compute.types import NodeState +from mock import Mock from epu.provisioner.ctx import ContextResource -from epu.test import Mock from epu.states import InstanceState import dashi.bootstrap diff --git a/epu/test/__init__.py b/epu/test/__init__.py index 1458dc86..7ba691cf 100644 --- a/epu/test/__init__.py +++ b/epu/test/__init__.py @@ -14,17 +14,6 @@ log = logging.getLogger(__name__) -class Mock(object): - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - - def __repr__(self): - return self.__str__() - - def __str__(self): - return "Mock(" + ",".join("%s=%s" % (k, v) for k, v in self.__dict__.iteritems()) + ")" - - class MockLeader(object): def __init__(self):