diff --git a/src/Parallel/Callback.hpp b/src/Parallel/Callback.hpp index 18c9084eab14..f12aaa506526 100644 --- a/src/Parallel/Callback.hpp +++ b/src/Parallel/Callback.hpp @@ -6,15 +6,45 @@ #pragma once +#include #include #include #include #include "Parallel/Invoke.hpp" +#include "Utilities/PrettyType.hpp" #include "Utilities/Serialization/CharmPupable.hpp" #include "Utilities/Serialization/RegisterDerivedClassesWithCharm.hpp" +#include "Utilities/TypeTraits/HasEquivalence.hpp" namespace Parallel { +namespace detail { +// Not all tuple arguments are guaranteed to have operator==, so we check the +// ones we can. +template +bool tuple_equal(const std::tuple& tuple_1, + const std::tuple& tuple_2) { + bool result = true; + tmpl::for_each, + tmpl::size>::value>>( + [&](const auto index_v) { + constexpr size_t index = tmpl::type_from::value; + + if (not result) { + return; + } + + if constexpr (tt::has_equivalence_v( + tuple_1))>) { + result = + result and std::get(tuple_1) == std::get(tuple_2); + } + }); + + return result; +} +} // namespace detail + /// An abstract base class, whose derived class holds a function that /// can be invoked at a later time. The function is intended to be /// invoked only once. @@ -30,6 +60,12 @@ class Callback : public PUP::able { explicit Callback(CkMigrateMessage* msg) : PUP::able(msg) {} virtual void invoke() = 0; virtual void register_with_charm() = 0; + /*! + * \brief Returns if this callback is equal to the one passed in. + */ + virtual bool is_equal_to(const Callback& rhs) const = 0; + virtual std::string name() const = 0; + virtual std::unique_ptr get_clone() = 0; }; /// Wraps a call to a simple action and its arguments. @@ -65,6 +101,27 @@ class SimpleActionCallback : public Callback { register_classes_with_charm(); } + bool is_equal_to(const Callback& rhs) const override { + const auto* downcast_ptr = dynamic_cast(&rhs); + if (downcast_ptr == nullptr) { + return false; + } + return detail::tuple_equal(args_, downcast_ptr->args_); + } + + std::string name() const override { + // Use pretty_type::get_name with the action since we want to differentiate + // template paremeters. Only use pretty_type::name for proxy because it'll + // likely be really long with the template parameters which is unnecessary + return "SimpleActionCallback(" + pretty_type::get_name() + + "," + pretty_type::name() + ")"; + } + + std::unique_ptr get_clone() override { + return std::make_unique>( + *this); + } + private: std::decay_t proxy_{}; std::tuple...> args_{}; @@ -93,6 +150,23 @@ class SimpleActionCallback : public Callback { register_classes_with_charm(); } + bool is_equal_to(const Callback& rhs) const override { + const auto* downcast_ptr = dynamic_cast(&rhs); + return downcast_ptr != nullptr; + } + + std::string name() const override { + // Use pretty_type::get_name with the action since we want to differentiate + // template paremeters. Only use pretty_type::name for proxy because it'll + // likely be really long with the template parameters which is unnecessary + return "SimpleActionCallback(" + pretty_type::get_name() + + "," + pretty_type::name() + ")"; + } + + std::unique_ptr get_clone() override { + return std::make_unique>(*this); + } + private: std::decay_t proxy_{}; }; @@ -130,6 +204,28 @@ class ThreadedActionCallback : public Callback { register_classes_with_charm(); } + bool is_equal_to(const Callback& rhs) const override { + const auto* downcast_ptr = + dynamic_cast(&rhs); + if (downcast_ptr == nullptr) { + return false; + } + return detail::tuple_equal(args_, downcast_ptr->args_); + } + + std::string name() const override { + // Use pretty_type::get_name with the action since we want to differentiate + // template paremeters. Only use pretty_type::name for proxy because it'll + // likely be really long with the template parameters which is unnecessary + return "ThreadedActionCallback(" + pretty_type::get_name() + + "," + pretty_type::name() + ")"; + } + + std::unique_ptr get_clone() override { + return std::make_unique< + ThreadedActionCallback>(*this); + } + private: std::decay_t proxy_{}; std::tuple...> args_{}; @@ -158,6 +254,25 @@ class ThreadedActionCallback : public Callback { register_classes_with_charm(); } + bool is_equal_to(const Callback& rhs) const override { + const auto* downcast_ptr = + dynamic_cast(&rhs); + return downcast_ptr != nullptr; + } + + std::string name() const override { + // Use pretty_type::get_name with the action since we want to differentiate + // template paremeters. Only use pretty_type::name for proxy because it'll + // likely be really long with the template parameters which is unnecessary + return "ThreadedActionCallback(" + pretty_type::get_name() + + "," + pretty_type::name() + ")"; + } + + std::unique_ptr get_clone() override { + return std::make_unique>( + *this); + } + private: std::decay_t proxy_{}; }; @@ -184,6 +299,22 @@ class PerformAlgorithmCallback : public Callback { register_classes_with_charm(); } + bool is_equal_to(const Callback& rhs) const override { + const auto* downcast_ptr = + dynamic_cast(&rhs); + return downcast_ptr != nullptr; + } + + std::string name() const override { + // Only use pretty_type::name for proxy because it'll likely be really long + // with the template parameters which is unnecessary + return "PerformAlgorithmCallback(" + pretty_type::name() + ")"; + } + + std::unique_ptr get_clone() override { + return std::make_unique>(*this); + } + private: std::decay_t proxy_{}; }; @@ -194,7 +325,7 @@ template PUP::able::PUP_ID PerformAlgorithmCallback::my_PUP_ID = 0; template PUP::able::PUP_ID -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) SimpleActionCallback::my_PUP_ID = 0; // NOLINT template @@ -203,7 +334,7 @@ PUP::able::PUP_ID SimpleActionCallback::my_PUP_ID = 0; // NOLINT template PUP::able::PUP_ID -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) ThreadedActionCallback::my_PUP_ID = 0; // NOLINT template diff --git a/src/Parallel/GlobalCache.hpp b/src/Parallel/GlobalCache.hpp index bd2ce341a585..8417c4b59cb0 100644 --- a/src/Parallel/GlobalCache.hpp +++ b/src/Parallel/GlobalCache.hpp @@ -28,6 +28,7 @@ #include "Parallel/ParallelComponentHelpers.hpp" #include "Parallel/ResourceInfo.hpp" #include "Parallel/Tags/ResourceInfo.hpp" +#include "Utilities/Algorithm.hpp" #include "Utilities/ErrorHandling/Assert.hpp" #include "Utilities/ErrorHandling/Error.hpp" #include "Utilities/Gsl.hpp" @@ -77,10 +78,10 @@ CREATE_GET_TYPE_ALIAS_OR_DEFAULT(component_being_mocked) template auto make_mutable_cache_tag_storage(tuples::TaggedTuple&& input) { - return tuples::TaggedTuple...>( - std::make_tuple(std::move(tuples::get(input)), - std::unordered_map>{})...); + return tuples::TaggedTuple...>(std::make_tuple( + std::move(tuples::get(input)), + std::unordered_map>>{})...); } template @@ -487,14 +488,34 @@ bool GlobalCache::mutable_cache_item_is_ready( optional_callback->register_with_charm(); // Second mutex is for vector of callbacks std::mutex& mutex = tuples::get>(mutexes_).second; + const std::unique_ptr clone_of_optional_callback = + optional_callback->get_clone(); { // Scoped for lock guard const std::lock_guard lock(mutex); - std::unordered_map>& - callbacks = std::get<1>(tuples::get(mutable_global_cache_)); - - if (callbacks.count(array_component_id) != 1) { - callbacks[array_component_id] = std::move(optional_callback); + std::unordered_map>>& callbacks = + std::get<1>(tuples::get(mutable_global_cache_)); + + if (callbacks.contains(array_component_id)) { + // If this array component id already exists, we don't want to add + // multiple of the same callback, so we loop over the existing callbacks + // and only if none of the existing callbacks are equal to the optional + // callback do we move the optional callback into the vector + auto& vec_callbacks = callbacks.at(array_component_id); + if (alg::none_of(vec_callbacks, + [&](const std::unique_ptr& local_callback) { + return local_callback->is_equal_to( + *optional_callback); + })) { + vec_callbacks.emplace_back(std::move(optional_callback)); + } + } else { + // If we don't have this array component id, then we create the vector + // and move the optional callback into the vector + callbacks[array_component_id] = + std::vector>(1); + callbacks.at(array_component_id)[0] = std::move(optional_callback); } } @@ -531,10 +552,26 @@ bool GlobalCache::mutable_cache_item_is_ready( const bool cache_item_is_ready = not callback_was_registered(); if (cache_item_is_ready) { const std::lock_guard lock(mutex); - std::unordered_map>& - callbacks = std::get<1>(tuples::get(mutable_global_cache_)); - - callbacks.erase(array_component_id); + std::unordered_map>>& callbacks = + std::get<1>(tuples::get(mutable_global_cache_)); + + // It's possible that no new callbacks were registered, so make sure this + // array component id still has callbacks before trying to remove them. + if (callbacks.contains(array_component_id)) { + // If this callback was a duplicate, we'll have to search through all + // callbacks to determine which to remove. If it wasn't a duplicate, + // then it'll just be the last callback in the vector. + auto& vec_callbacks = callbacks.at(array_component_id); + std::erase_if(vec_callbacks, + [&clone_of_optional_callback](const auto& t) { + return t->is_equal_to(*clone_of_optional_callback); + }); + + if (callbacks.at(array_component_id).empty()) { + callbacks.erase(array_component_id); + } + } } return cache_item_is_ready; @@ -573,7 +610,8 @@ void GlobalCache::mutate(const std::tuple& args) { // Therefore, after locking it, we std::move the map of callbacks into a // temporary map, clear the original map, and invoke the callbacks in the // temporary map. - std::unordered_map> + std::unordered_map>> callbacks{}; // Second mutex is for map of callbacks std::mutex& mutex = tuples::get>(mutexes_).second; @@ -587,9 +625,10 @@ void GlobalCache::mutate(const std::tuple& args) { // Invoke the callbacks. Any new callbacks that are added to the // list (if a callback calls mutable_cache_item_is_ready) will be // saved and will not be invoked here. - for (auto& [array_component_id, callback] : callbacks) { - (void)array_component_id; - callback->invoke(); + for (auto& [array_component_id, vec_callbacks] : callbacks) { + for (auto& callback : vec_callbacks) { + callback->invoke(); + } } } diff --git a/src/Parallel/ParallelComponentHelpers.hpp b/src/Parallel/ParallelComponentHelpers.hpp index 3aaa78981308..e7b36d34ab87 100644 --- a/src/Parallel/ParallelComponentHelpers.hpp +++ b/src/Parallel/ParallelComponentHelpers.hpp @@ -191,9 +191,10 @@ struct MutexTag { template struct MutableCacheTag { using tag = Tag; - using type = std::tuple>>; + using type = + std::tuple>>>; }; template diff --git a/tests/Unit/Parallel/Test_Callback.cpp b/tests/Unit/Parallel/Test_Callback.cpp index 135eb22f7946..dab686bb4510 100644 --- a/tests/Unit/Parallel/Test_Callback.cpp +++ b/tests/Unit/Parallel/Test_Callback.cpp @@ -247,15 +247,34 @@ struct RunCallbacks { Parallel::SimpleActionCallback callback_2(proxy_2, 1.5); + SPECTRE_PARALLEL_REQUIRE( + callback_0.name().find("PerformAlgorithmCallback") != + std::string::npos); + SPECTRE_PARALLEL_REQUIRE( + (callback_1.name().find("SimpleActionCallback") != std::string::npos and + callback_1.name().find("IncrementValue") != std::string::npos)); + SPECTRE_PARALLEL_REQUIRE( + (callback_2.name().find("SimpleActionCallback") != std::string::npos and + callback_2.name().find("MultiplyValueByFactor") != std::string::npos)); callback_0.invoke(); callback_1.invoke(); callback_2.invoke(); auto callback_3 = serialize_and_deserialize(callback_0); auto callback_4 = serialize_and_deserialize(callback_1); auto callback_5 = serialize_and_deserialize(callback_2); + SPECTRE_PARALLEL_REQUIRE(callback_0.is_equal_to(callback_3)); + SPECTRE_PARALLEL_REQUIRE_FALSE(callback_0.is_equal_to(callback_4)); + SPECTRE_PARALLEL_REQUIRE(callback_1.is_equal_to(callback_4)); + SPECTRE_PARALLEL_REQUIRE_FALSE(callback_1.is_equal_to(callback_5)); callback_3.invoke(); callback_4.invoke(); callback_5.invoke(); + const auto callback_6 = callback_0.get_clone(); + const auto callback_7 = callback_1.get_clone(); + const auto callback_8 = callback_2.get_clone(); + SPECTRE_PARALLEL_REQUIRE(callback_0.is_equal_to(*callback_6)); + SPECTRE_PARALLEL_REQUIRE(callback_1.is_equal_to(*callback_7)); + SPECTRE_PARALLEL_REQUIRE(callback_2.is_equal_to(*callback_8)); std::vector> callbacks; callbacks.emplace_back( std::make_unique>( @@ -267,6 +286,7 @@ struct RunCallbacks { callbacks.emplace_back( std::make_unique>(proxy_2, 2.0)); + SPECTRE_PARALLEL_REQUIRE_FALSE(callback_2.is_equal_to(*callbacks.back())); auto& nodegroup_proxy = Parallel::get_parallel_component>(cache); diff --git a/tests/Unit/Parallel/Test_GlobalCache.ci b/tests/Unit/Parallel/Test_GlobalCache.ci index f62d6d133dc8..3076b86c0e9e 100644 --- a/tests/Unit/Parallel/Test_GlobalCache.ci +++ b/tests/Unit/Parallel/Test_GlobalCache.ci @@ -18,5 +18,6 @@ mainmodule Test_GlobalCache { entry void run_test_three(); entry void run_test_four(); entry void run_test_five(); + entry void mutate_name(); }; } diff --git a/tests/Unit/Parallel/Test_GlobalCache.cpp b/tests/Unit/Parallel/Test_GlobalCache.cpp index 86cbec5279a8..875b62ba9625 100644 --- a/tests/Unit/Parallel/Test_GlobalCache.cpp +++ b/tests/Unit/Parallel/Test_GlobalCache.cpp @@ -152,7 +152,7 @@ struct modify_number_of_legs { template struct SingletonParallelComponent { using chare_type = Parallel::Algorithms::Singleton; - using const_global_cache_tags = tmpl::list; + using const_global_cache_tags = tmpl::list; using mutable_global_cache_tags = tmpl::list; using metavariables = Metavariables; using phase_dependent_action_list = tmpl::list< @@ -177,8 +177,8 @@ struct ArrayParallelComponent { template struct GroupParallelComponent { using chare_type = Parallel::Algorithms::Group; - using const_global_cache_tags = tmpl::list; - using mutable_global_cache_tags = tmpl::list; + using const_global_cache_tags = tmpl::list<>; + using mutable_global_cache_tags = tmpl::list; using metavariables = Metavariables; using phase_dependent_action_list = tmpl::list< Parallel::PhaseActions>>; @@ -219,23 +219,54 @@ class UseCkCallbackAsCallback : public Parallel::Callback { UseCkCallbackAsCallback() = default; explicit UseCkCallbackAsCallback(CkMigrateMessage* msg) : Parallel::Callback(msg) {} - explicit UseCkCallbackAsCallback(const CkCallback& callback) - : callback_(callback) {} + explicit UseCkCallbackAsCallback(const CkCallback& callback, + const size_t index) + : callback_(callback), index_(index) {} using PUP::able::register_constructor; void invoke() override { callback_.send(nullptr); } void pup(PUP::er& p) override { p | callback_; } // We shouldn't be pupping so registration doesn't matter void register_with_charm() override {} + bool is_equal_to(const Parallel::Callback& rhs) const override { + const auto* downcast_ptr = + dynamic_cast(&rhs); + if (downcast_ptr == nullptr) { + return false; + } + return index_ == downcast_ptr->index_; + } + + std::string name() const override { + return "UseCkCallbackAsCallback" + std::to_string(index_); + } + + std::unique_ptr get_clone() override { + return std::make_unique(callback_, index_); + } private: CkCallback callback_; + size_t index_{}; }; +size_t calls_to_test_one = 0; // NOLINT + +template +void TestArrayChare::mutate_name() { + auto& local_cache = *Parallel::local_branch(global_cache_proxy_); + + // Turn Nobody into Somebody + Parallel::mutate>(local_cache, + std::string("Somebody")); +} + template void TestArrayChare::run_test_one() { // Test that the values are what we think they should be. auto& local_cache = *Parallel::local_branch(global_cache_proxy_); - SPECTRE_PARALLEL_REQUIRE("Nobody" == Parallel::get(local_cache)); + const std::string expected_name = + calls_to_test_one == 0 ? "Nobody" : "Somebody"; + SPECTRE_PARALLEL_REQUIRE(expected_name == Parallel::get(local_cache)); SPECTRE_PARALLEL_REQUIRE(178 == Parallel::get(local_cache)); SPECTRE_PARALLEL_REQUIRE(2.2 == Parallel::get(local_cache)); SPECTRE_PARALLEL_REQUIRE( @@ -261,7 +292,7 @@ void TestArrayChare::run_test_one() { serialize_and_deserialize( make_not_null(&serialized_and_deserialized_global_cache), local_cache); SPECTRE_PARALLEL_REQUIRE( - "Nobody" == + expected_name == Parallel::get(serialized_and_deserialized_global_cache)); SPECTRE_PARALLEL_REQUIRE( 178 == Parallel::get(serialized_and_deserialized_global_cache)); @@ -271,6 +302,57 @@ void TestArrayChare::run_test_one() { serialized_and_deserialized_global_cache) .number_of_sides()); + // Only register callbacks on the first call to `run_test_one` + if (calls_to_test_one == 0) { + auto callback = + CkCallback(CkIndex_TestArrayChare::run_test_one(), + this->thisProxy[this->thisIndex]); + const auto array_component_id = + Parallel::make_array_component_id>( + static_cast(this->thisIndex)); + + const auto register_callback = [&](const size_t index) { + Parallel::mutable_cache_item_is_ready( + local_cache, array_component_id, + [&callback, &index](const std::string& name_l) + -> std::unique_ptr { + return name_l == "Somebody" + ? std::unique_ptr{} + : std::unique_ptr( + new UseCkCallbackAsCallback(callback, index)); + }); + }; + + // Register first callback for this function + register_callback(0); + // Register second callback for this function with different index to test + // that we can have two callbacks on the same element + register_callback(1); + // Try and register the first callback again. This shouldn't be registered + // and we should still only have two callbacks + register_callback(0); + + // Do the mutate somehwere else so we can return here and have + // `run_test_one` be called again + this->thisProxy[0].mutate_name(); + + calls_to_test_one++; + return; + } else if (calls_to_test_one == 1) { + // Make sure we haven't mutated the weight yet, because that would move on + // to `run_test_two` then + SPECTRE_PARALLEL_REQUIRE(Parallel::get(local_cache) == 160.0); + + calls_to_test_one++; + return; + } else { + // We have now called both registered callbacks + SPECTRE_PARALLEL_REQUIRE(calls_to_test_one == 2); + + // The value will be checked in `run_test_two` + calls_to_test_one++; + } + // Mutate the weight to 150. Parallel::mutate>(local_cache, 150.0); } @@ -288,10 +370,14 @@ void TestArrayChare::run_test_two() { *Parallel::local_branch(global_cache_proxy_), array_component_id, [&callback]( const double& weight_l) -> std::unique_ptr { - return weight_l == 150 ? std::unique_ptr{} - : std::unique_ptr( - new UseCkCallbackAsCallback(callback)); + return weight_l == 150 + ? std::unique_ptr{} + : std::unique_ptr( + new UseCkCallbackAsCallback(callback, 0)); })) { + // One original call, and then two callbacks + SPECTRE_PARALLEL_REQUIRE(calls_to_test_one == 3); + auto& local_cache = *Parallel::local_branch(global_cache_proxy_); SPECTRE_PARALLEL_REQUIRE(150 == Parallel::get(local_cache)); @@ -319,7 +405,7 @@ void TestArrayChare::run_test_three() { return email_l == "albert@einstein.de" ? std::unique_ptr{} : std::unique_ptr( - new UseCkCallbackAsCallback(callback)); + new UseCkCallbackAsCallback(callback, 0)); })) { auto& local_cache = *Parallel::local_branch(global_cache_proxy_); SPECTRE_PARALLEL_REQUIRE("albert@einstein.de" == @@ -346,7 +432,7 @@ void TestArrayChare::run_test_four() { return animal_l.number_of_legs() == 8 ? std::unique_ptr{} : std::unique_ptr( - new UseCkCallbackAsCallback(callback)); + new UseCkCallbackAsCallback(callback, 0)); })) { auto& local_cache = *Parallel::local_branch(global_cache_proxy_); SPECTRE_PARALLEL_REQUIRE( @@ -373,7 +459,7 @@ void TestArrayChare::run_test_five() { return animal_l.number_of_legs() == 30 ? std::unique_ptr{} : std::unique_ptr( - new UseCkCallbackAsCallback(callback)); + new UseCkCallbackAsCallback(callback, 0)); })) { auto& local_cache = *Parallel::local_branch(global_cache_proxy_); SPECTRE_PARALLEL_REQUIRE( @@ -389,21 +475,21 @@ template void Test_GlobalCache::run_single_core_test() { using const_tag_list = typename Parallel::get_const_global_cache_tags; - static_assert(std::is_same_v>, - "Wrong const_tag_list in GlobalCache test"); + static_assert( + std::is_same_v>, + "Wrong const_tag_list in GlobalCache test"); using mutable_tag_list = typename Parallel::get_mutable_global_cache_tags; static_assert( - std::is_same_v>, + std::is_same_v>, "Wrong mutable_tag_list in GlobalCache test"); tuples::tagged_tuple_from_typelist const_data_to_be_cached( - "Nobody", 178, 2.2, std::make_unique()); + 178, 2.2, std::make_unique()); tuples::tagged_tuple_from_typelist mutable_data_to_be_cached(160, std::make_unique(6), - "joe@somewhere.com"); + "joe@somewhere.com", "Nobody"); Parallel::GlobalCache cache( std::move(const_data_to_be_cached), std::move(mutable_data_to_be_cached)); @@ -458,7 +544,7 @@ Test_GlobalCache::Test_GlobalCache(CkArgMsg* using mutable_tag_list = typename Parallel::get_mutable_global_cache_tags; static_assert( - std::is_same_v>, + std::is_same_v>, "Wrong mutable_tag_list in GlobalCache test"); static_assert( Parallel::is_in_mutable_global_cache); @@ -473,9 +559,9 @@ Test_GlobalCache::Test_GlobalCache(CkArgMsg* using const_tag_list = typename Parallel::get_const_global_cache_tags; - static_assert(std::is_same_v>, - "Wrong const_tag_list in GlobalCache test"); + static_assert( + std::is_same_v>, + "Wrong const_tag_list in GlobalCache test"); static_assert( Parallel::is_in_const_global_cache); static_assert( @@ -493,9 +579,9 @@ Test_GlobalCache::Test_GlobalCache(CkArgMsg* // Arthropod begins as an insect. tuples::tagged_tuple_from_typelist mutable_data_to_be_cached(160, std::make_unique(6), - "joe@somewhere.com"); + "joe@somewhere.com", "Nobody"); tuples::tagged_tuple_from_typelist const_data_to_be_cached( - "Nobody", 178, 2.2, std::make_unique()); + 178, 2.2, std::make_unique()); global_cache_proxy_ = Parallel::CProxy_GlobalCache::ckNew( std::move(const_data_to_be_cached), std::move(mutable_data_to_be_cached), std::nullopt); diff --git a/tests/Unit/Parallel/Test_GlobalCache.hpp b/tests/Unit/Parallel/Test_GlobalCache.hpp index c9ef7d6326bf..c286f9ebefd0 100644 --- a/tests/Unit/Parallel/Test_GlobalCache.hpp +++ b/tests/Unit/Parallel/Test_GlobalCache.hpp @@ -66,6 +66,7 @@ class TestArrayChare : public CBase_TestArrayChare { void run_test_three(); void run_test_four(); void run_test_five(); + void mutate_name(); private: CProxy_Test_GlobalCache main_proxy_;