From 674db5e4d2748da26acadb54f3570650e2e0b4d9 Mon Sep 17 00:00:00 2001 From: IndecisiveTurtle <47210458+raphaelthegreat@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:57:18 +0300 Subject: [PATCH] kernel: Rework semaphore implementation --- .../libraries/kernel/threads/semaphore.cpp | 113 +++++++++++------- 1 file changed, 72 insertions(+), 41 deletions(-) diff --git a/src/core/libraries/kernel/threads/semaphore.cpp b/src/core/libraries/kernel/threads/semaphore.cpp index e6fc667d2..63c33de72 100644 --- a/src/core/libraries/kernel/threads/semaphore.cpp +++ b/src/core/libraries/kernel/threads/semaphore.cpp @@ -3,13 +3,14 @@ #include #include +#include #include #include #include "common/assert.h" #include "common/logging/log.h" -#include "common/scope_exit.h" #include "core/libraries/error_codes.h" #include "core/libraries/libs.h" +#include "core/libraries/kernel/thread_management.h" namespace Libraries::Kernel { @@ -18,26 +19,30 @@ using ListBaseHook = class Semaphore { public: - Semaphore(s32 init_count, s32 max_count, const char* name, bool is_fifo) - : name{name}, token_count{init_count}, max_count{max_count}, is_fifo{is_fifo} {} + Semaphore(s32 init_count, s32 max_count, std::string_view name, bool is_fifo) + : name{name}, token_count{init_count}, max_count{max_count}, + init_count{init_count}, is_fifo{is_fifo} {} + ~Semaphore() { + ASSERT(wait_list.empty()); + } - bool Wait(bool can_block, s32 need_count, u64* timeout) { - if (HasAvailableTokens(need_count)) { - return true; + int Wait(bool can_block, s32 need_count, u32* timeout) { + std::unique_lock lk{mutex}; + if (token_count >= need_count) { + token_count -= need_count; + return ORBIS_OK; } if (!can_block) { - return false; + return ORBIS_KERNEL_ERROR_EBUSY; } // Create waiting thread object and add it into the list of waiters. WaitingThread waiter{need_count, is_fifo}; AddWaiter(waiter); - SCOPE_EXIT { - PopWaiter(waiter); - }; // Perform the wait. - return waiter.Wait(timeout); + std::exchange(lk, std::unique_lock{waiter.mutex}); + return waiter.Wait(lk, timeout); } bool Signal(s32 signal_count) { @@ -48,25 +53,47 @@ public: token_count += signal_count; // Wake up threads in order of priority. - for (auto& waiter : wait_list) { + for (auto it = wait_list.begin(); it != wait_list.end();) { + auto& waiter = *it; if (waiter.need_count > token_count) { + it++; continue; } + std::scoped_lock lk2{waiter.mutex}; token_count -= waiter.need_count; waiter.cv.notify_one(); + it = wait_list.erase(it); } return true; } -private: + int Cancel(s32 set_count, s32* num_waiters) { + std::scoped_lock lk{mutex}; + if (num_waiters) { + *num_waiters = wait_list.size(); + } + for (auto& waiter : wait_list) { + waiter.was_cancled = true; + waiter.cv.notify_one(); + } + wait_list.clear(); + token_count = set_count < 0 ? init_count : set_count; + return ORBIS_OK; + } + +public: struct WaitingThread : public ListBaseHook { std::mutex mutex; + std::string name; std::condition_variable cv; u32 priority; s32 need_count; + bool was_deleted{}; + bool was_cancled{}; explicit WaitingThread(s32 need_count, bool is_fifo) : need_count{need_count} { + name = scePthreadSelf()->name; if (is_fifo) { return; } @@ -77,12 +104,24 @@ private: priority = param.sched_priority; } - bool Wait(u64* timeout) { - std::unique_lock lk{mutex}; + int GetResult(bool timed_out) { + if (timed_out) { + return SCE_KERNEL_ERROR_ETIMEDOUT; + } + if (was_deleted) { + return SCE_KERNEL_ERROR_EACCES; + } + if (was_cancled) { + return SCE_KERNEL_ERROR_ECANCELED; + } + return SCE_OK; + } + + int Wait(std::unique_lock& lk, u32* timeout) { if (!timeout) { // Wait indefinitely until we are woken up. cv.wait(lk); - return true; + return GetResult(false); } // Wait until timeout runs out, recording how much remaining time there was. const auto start = std::chrono::high_resolution_clock::now(); @@ -91,16 +130,11 @@ private: const auto time = std::chrono::duration_cast(end - start).count(); *timeout -= time; - return status != std::cv_status::timeout; - } - - bool operator<(const WaitingThread& other) const { - return priority < other.priority; + return GetResult(status == std::cv_status::timeout); } }; void AddWaiter(WaitingThread& waiter) { - std::scoped_lock lk{mutex}; // Insert at the end of the list for FIFO order. if (is_fifo) { wait_list.push_back(waiter); @@ -114,20 +148,6 @@ private: wait_list.insert(it, waiter); } - void PopWaiter(WaitingThread& waiter) { - std::scoped_lock lk{mutex}; - wait_list.erase(WaitingThreads::s_iterator_to(waiter)); - } - - bool HasAvailableTokens(s32 need_count) { - std::scoped_lock lk{mutex}; - if (token_count >= need_count) { - token_count -= need_count; - return true; - } - return false; - } - using WaitingThreads = boost::intrusive::list, boost::intrusive::constant_time_size>; @@ -136,6 +156,7 @@ private: std::atomic token_count; std::mutex mutex; s32 max_count; + s32 init_count; bool is_fifo; }; @@ -151,9 +172,8 @@ s32 PS4_SYSV_ABI sceKernelCreateSema(OrbisKernelSema* sem, const char* pName, u3 return ORBIS_OK; } -s32 PS4_SYSV_ABI sceKernelWaitSema(OrbisKernelSema sem, s32 needCount, u64* pTimeout) { - ASSERT(sem->Wait(true, needCount, pTimeout)); - return ORBIS_OK; +s32 PS4_SYSV_ABI sceKernelWaitSema(OrbisKernelSema sem, s32 needCount, u32* pTimeout) { + return sem->Wait(true, needCount, pTimeout); } s32 PS4_SYSV_ABI sceKernelSignalSema(OrbisKernelSema sem, s32 signalCount) { @@ -164,9 +184,18 @@ s32 PS4_SYSV_ABI sceKernelSignalSema(OrbisKernelSema sem, s32 signalCount) { } s32 PS4_SYSV_ABI sceKernelPollSema(OrbisKernelSema sem, s32 needCount) { - if (!sem->Wait(false, needCount, nullptr)) { - return ORBIS_KERNEL_ERROR_EBUSY; + return sem->Wait(false, needCount, nullptr); +} + +int PS4_SYSV_ABI sceKernelCancelSema(OrbisKernelSema sem, s32 setCount, s32 *pNumWaitThreads) { + return sem->Cancel(setCount, pNumWaitThreads); +} + +int PS4_SYSV_ABI sceKernelDeleteSema(OrbisKernelSema sem) { + if (!sem) { + return SCE_KERNEL_ERROR_ESRCH; } + delete sem; return ORBIS_OK; } @@ -175,6 +204,8 @@ void SemaphoreSymbolsRegister(Core::Loader::SymbolsResolver* sym) { LIB_FUNCTION("Zxa0VhQVTsk", "libkernel", 1, "libkernel", 1, 1, sceKernelWaitSema); LIB_FUNCTION("4czppHBiriw", "libkernel", 1, "libkernel", 1, 1, sceKernelSignalSema); LIB_FUNCTION("12wOHk8ywb0", "libkernel", 1, "libkernel", 1, 1, sceKernelPollSema); + LIB_FUNCTION("4DM06U2BNEY", "libkernel", 1, "libkernel", 1, 1, sceKernelCancelSema); + LIB_FUNCTION("R1Jvn8bSCW8", "libkernel", 1, "libkernel", 1, 1, sceKernelDeleteSema); } } // namespace Libraries::Kernel