diff --git a/src/xrt/auxiliary/CMakeLists.txt b/src/xrt/auxiliary/CMakeLists.txt index 85b5367cd..ad6027081 100644 --- a/src/xrt/auxiliary/CMakeLists.txt +++ b/src/xrt/auxiliary/CMakeLists.txt @@ -127,6 +127,7 @@ set(UTIL_SOURCE_FILES util/u_format.h util/u_frame.c util/u_frame.h + util/u_generic_callbacks.hpp util/u_git_tag.h util/u_hand_tracking.c util/u_hand_tracking.h diff --git a/src/xrt/auxiliary/meson.build b/src/xrt/auxiliary/meson.build index 8433d434f..2d2e2bc0f 100644 --- a/src/xrt/auxiliary/meson.build +++ b/src/xrt/auxiliary/meson.build @@ -36,6 +36,7 @@ lib_aux_util = static_library( 'util/u_format.h', 'util/u_frame.c', 'util/u_frame.h', + 'util/u_generic_callbacks.hpp', 'util/u_git_tag.h', 'util/u_hand_tracking.c', 'util/u_hand_tracking.h', diff --git a/src/xrt/auxiliary/util/u_generic_callbacks.hpp b/src/xrt/auxiliary/util/u_generic_callbacks.hpp new file mode 100644 index 000000000..f3d6c13f7 --- /dev/null +++ b/src/xrt/auxiliary/util/u_generic_callbacks.hpp @@ -0,0 +1,231 @@ +// Copyright 2021, Collabora, Ltd. +// SPDX-License-Identifier: BSL-1.0 +/*! + * @file + * @brief Implementation of a generic callback collection, intended to be wrapped for a specific event type. + * @author Ryan Pavlik + * @ingroup aux_util + */ + +#include +#include + +namespace xrt::auxiliary::util { +template struct GenericCallbacks; + +namespace detail { + + /*! + * @brief Element type stored in @ref GenericCallbacks, for internal use only. + */ + template struct GenericCallbackEntry + { + CallbackType callback; + MaskType event_mask; + void *userdata; + bool should_remove = false; + + GenericCallbackEntry(CallbackType callback_, MaskType event_mask_, void *userdata_) noexcept + : callback(callback_), event_mask(event_mask_), userdata(userdata_) + {} + + /*! + * Do the two entries match? Used for removal "by value" + */ + bool + matches(GenericCallbackEntry const &other) const noexcept + { + return callback == other.callback && event_mask == other.event_mask && + userdata == other.userdata; + } + + bool + operator==(GenericCallbackEntry const &other) const noexcept + { + return matches(other); + } + + bool + shouldInvoke(MaskType event) const noexcept + { + return (event_mask & event) != 0; + } + }; + + template struct identity + { + using type = T; + }; + + // This lets us handle being passed an enum (which we can call underlying_type on) as well as an integer (which + // we cannot) + template + using mask_from_enum_t = + typename std::conditional_t::value, std::underlying_type, identity>::type; + +} // namespace detail + +/*! + * @brief A generic collection of callbacks for event types represented as a bitmask, intended to be wrapped for each + * usage. + * + * A registered callback may identify one or more event types (bits in the bitmask) that it wants to be invoked for. A + * userdata void pointer is also stored for each callback. Bitmasks are tested at invocation time, and the general + * callback format allows for callbacks to indicate they should be removed from the collection. Actually calling each + * callback is left to a consumer-provided "invoker" to allow adding context and event data to the call. The "invoker" + * also allows the option of whether or how to expose the self-removal capability: yours might simply always return + * "false". + * + * This generic structure supports callbacks that are included multiple times in the collection, if the consuming code + * needs it. GenericCallbacks::contains may be used by consuming code before conditionally calling addCallback, to + * limit to a single instance in a collection. + * + * @tparam CallbackType the function pointer type to store for each callback. + * @tparam EventType the event enum type. + */ +template struct GenericCallbacks +{ + +public: + static_assert(std::is_integral::value || std::is_enum::value, + "Your event type must either be an integer or an enum"); + using callback_t = CallbackType; + using event_t = EventType; + using mask_t = detail::mask_from_enum_t; + +private: + static_assert(std::is_integral::value, "Our enum to mask conversion should have produced an integer"); + + //! The type stored for each added callback. + using callback_entry_t = detail::GenericCallbackEntry; + +public: + /*! + * @brief Add a new callback entry with the given callback function pointer, event mask, and user data. + * + * New callback entries are always added at the end of the collection. + */ + void + addCallback(CallbackType callback, mask_t event_mask, void *userdata) + { + callbacks.emplace_back(callback, event_mask, userdata); + } + + /*! + * @brief Remove some number of callback entries matching the given callback function pointer, event mask, and + * user data. + * + * @param callback The callback function pointer. Tested for equality with each callback entry. + * @param event_mask The callback event mask. Tested for equality with each callback entry. + * @param userdata The opaque user data pointer. Tested for equality with each callback entry. + * @param num_skip The number of matches to skip before starting to remove callbacks. Defaults to 0. + * @param max_remove The number of matches to remove, or negative if no limit. Defaults to -1. + * + * @returns the number of callbacks removed. + */ + int + removeCallback( + CallbackType callback, mask_t event_mask, void *userdata, unsigned int num_skip = 0, int max_remove = -1) + { + if (max_remove == 0) { + // We were told to remove none. We can do this very quickly. + // Avoids a corner case in the loop where we assume max_remove is non-zero. + return 0; + } + bool found = false; + + const callback_entry_t needle{callback, event_mask, userdata}; + for (auto &entry : callbacks) { + if (entry.matches(needle)) { + if (num_skip > 0) { + // We are still in our skipping phase. + num_skip--; + continue; + } + entry.should_remove = true; + found = true; + // Negatives (no max) get more negative, which is OK. + max_remove--; + if (max_remove == 0) { + // not looking for more + break; + } + } + } + if (found) { + return purgeMarkedCallbacks(); + } + // if we didn't find any, we removed zero. + return 0; + } + + /*! + * @brief See if the collection contains at least one matching callback. + * + * @param callback The callback function pointer. Tested for equality with each callback entry. + * @param event_mask The callback event mask. Tested for equality with each callback entry. + * @param userdata The opaque user data pointer. Tested for equality with each callback entry. + * + * @returns true if a matching callback is found. + */ + bool + contains(CallbackType callback, mask_t event_mask, void *userdata) + { + const callback_entry_t needle{callback, event_mask, userdata}; + auto it = std::find(callbacks.begin(), callbacks.end(), needle); + return it != callbacks.end(); + } + + /*! + * @brief Invokes the callbacks, by passing the ones we should run to your "invoker" to add any desired + * context/event data and forward the call. + * + * Callbacks are called in order, filtering out those whose event mask does not include the given event. + * + * @param event The event type to invoke callbacks for. + * @param invoker A function/functor accepting the event, a callback function pointer, and the callback entry's + * userdata as parameters, and returning true if the callback should be removed from the collection. It is + * assumed that the invoker will add any additional context or event data and call the provided callback. + * + * Typically, a lambda with some captures and a single return statement will be sufficient for an invoker. + * + * @returns the number of callbacks run + */ + template + int + invokeCallbacks(EventType event, F &&invoker) + { + bool needPurge = false; + + int ran = 0; + for (auto &entry : callbacks) { + if (entry.shouldInvoke(static_cast(event))) { + bool willRemove = invoker(event, entry.callback, entry.userdata); + if (willRemove) { + entry.should_remove = true; + needPurge = true; + } + ran++; + } + } + if (needPurge) { + purgeMarkedCallbacks(); + } + return ran; + } + +private: + std::vector callbacks; + + int + purgeMarkedCallbacks() + { + auto b = callbacks.begin(); + auto e = callbacks.end(); + auto new_end = std::remove_if(b, e, [](callback_entry_t const &entry) { return entry.should_remove; }); + auto num_removed = std::distance(new_end, e); + callbacks.erase(new_end, e); + return static_cast(num_removed); + } +}; +} // namespace xrt::auxiliary::util diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 8b71246d9..f261fb69f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -17,3 +17,9 @@ target_link_libraries(tests_input_transform PRIVATE xrt-external-openxr aux_util) add_test(NAME input_transform COMMAND tests_input_transform --success) + +# Generic callbacks +add_executable(tests_generic_callbacks tests_generic_callbacks.cpp) +target_link_libraries(tests_generic_callbacks PRIVATE tests_main) +target_link_libraries(tests_generic_callbacks PRIVATE aux_util) +add_test(NAME tests_generic_callbacks COMMAND tests_generic_callbacks --success) diff --git a/tests/tests_generic_callbacks.cpp b/tests/tests_generic_callbacks.cpp new file mode 100644 index 000000000..e3997723a --- /dev/null +++ b/tests/tests_generic_callbacks.cpp @@ -0,0 +1,104 @@ +// Copyright 2021, Collabora, Ltd. +// SPDX-License-Identifier: BSL-1.0 +/*! + * @file + * @brief Generic callback collection tests. + * @author Ryan Pavlik + */ + +#include "catch/catch.hpp" + +#include + +using xrt::auxiliary::util::GenericCallbacks; + +enum class MyEvent +{ + ACQUIRED, + LOST, +}; + +using mask_t = std::underlying_type_t; + +static bool +increment_userdata_int(MyEvent event, void *userdata) +{ + *static_cast(userdata) += 1; + return true; +} + + +using callback_t = bool (*)(MyEvent event, void *userdata); + +TEST_CASE("u_generic_callbacks") +{ + GenericCallbacks callbacks; + // Simplest possible invoker. + auto invoker = [](MyEvent event, callback_t callback, void *userdata) { return callback(event, userdata); }; + + SECTION("call when empty") + { + CHECK(0 == callbacks.invokeCallbacks(MyEvent::ACQUIRED, invoker)); + CHECK(0 == callbacks.invokeCallbacks(MyEvent::LOST, invoker)); + CHECK(0 == callbacks.removeCallback(&increment_userdata_int, (mask_t)MyEvent::LOST, nullptr)); + } + SECTION("same function, different mask and userdata") + { + int numAcquired = 0; + int numLost = 0; + REQUIRE_NOTHROW(callbacks.addCallback(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, &numAcquired)); + REQUIRE_NOTHROW(callbacks.addCallback(increment_userdata_int, (mask_t)MyEvent::LOST, &numLost)); + SECTION("contains") + { + CHECK(callbacks.contains(increment_userdata_int, (mask_t)MyEvent::LOST, &numLost)); + CHECK_FALSE(callbacks.contains(increment_userdata_int, (mask_t)MyEvent::LOST, &numAcquired)); + } + SECTION("removal matching") + { + CHECK(0 == + callbacks.removeCallback(increment_userdata_int, (mask_t)MyEvent::LOST, &numAcquired)); + CHECK(0 == + callbacks.removeCallback(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, &numLost)); + } + SECTION("duplicates, contains, and removal") + { + REQUIRE(callbacks.contains(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, &numAcquired)); + REQUIRE_NOTHROW( + callbacks.addCallback(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, &numAcquired)); + CHECK(callbacks.contains(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, &numAcquired)); + // Now we have two ACQUIRED and one LOST callback. + SECTION("max_remove") + { + CHECK(0 == callbacks.removeCallback(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, + &numAcquired, 0, 0)); + CHECK(callbacks.contains(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, + &numAcquired)); + + CHECK(1 == callbacks.removeCallback(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, + &numAcquired, 0, 1)); + CHECK(callbacks.contains(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, + &numAcquired)); + } + SECTION("large max_remove") + { + CHECK(2 == callbacks.removeCallback(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, + &numAcquired, 0, 3)); + CHECK_FALSE(callbacks.contains(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, + &numAcquired)); + } + SECTION("num_skip") + { + CHECK(0 == callbacks.removeCallback(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, + &numAcquired, 3)); + CHECK(callbacks.contains(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, + &numAcquired)); + + CHECK(1 == callbacks.removeCallback(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, + &numAcquired, 1)); + CHECK(callbacks.contains(increment_userdata_int, (mask_t)MyEvent::ACQUIRED, + &numAcquired)); + } + } + CHECK(1 == callbacks.removeCallback(increment_userdata_int, (mask_t)MyEvent::LOST, &numLost)); + } +}