kernel: Rework semaphore implementation

This commit is contained in:
IndecisiveTurtle 2024-06-26 17:57:18 +03:00
parent f489ecde86
commit 674db5e4d2

View file

@ -3,13 +3,14 @@
#include <condition_variable> #include <condition_variable>
#include <mutex> #include <mutex>
#include <utility>
#include <boost/intrusive/list.hpp> #include <boost/intrusive/list.hpp>
#include <pthread.h> #include <pthread.h>
#include "common/assert.h" #include "common/assert.h"
#include "common/logging/log.h" #include "common/logging/log.h"
#include "common/scope_exit.h"
#include "core/libraries/error_codes.h" #include "core/libraries/error_codes.h"
#include "core/libraries/libs.h" #include "core/libraries/libs.h"
#include "core/libraries/kernel/thread_management.h"
namespace Libraries::Kernel { namespace Libraries::Kernel {
@ -18,26 +19,30 @@ using ListBaseHook =
class Semaphore { class Semaphore {
public: public:
Semaphore(s32 init_count, s32 max_count, const char* name, bool is_fifo) Semaphore(s32 init_count, s32 max_count, std::string_view name, bool is_fifo)
: name{name}, token_count{init_count}, max_count{max_count}, is_fifo{is_fifo} {} : name{name}, token_count{init_count}, max_count{max_count},
init_count{init_count}, is_fifo{is_fifo} {}
~Semaphore() {
ASSERT(wait_list.empty());
}
bool Wait(bool can_block, s32 need_count, u64* timeout) { int Wait(bool can_block, s32 need_count, u32* timeout) {
if (HasAvailableTokens(need_count)) { std::unique_lock lk{mutex};
return true; if (token_count >= need_count) {
token_count -= need_count;
return ORBIS_OK;
} }
if (!can_block) { if (!can_block) {
return false; return ORBIS_KERNEL_ERROR_EBUSY;
} }
// Create waiting thread object and add it into the list of waiters. // Create waiting thread object and add it into the list of waiters.
WaitingThread waiter{need_count, is_fifo}; WaitingThread waiter{need_count, is_fifo};
AddWaiter(waiter); AddWaiter(waiter);
SCOPE_EXIT {
PopWaiter(waiter);
};
// Perform the wait. // Perform the wait.
return waiter.Wait(timeout); std::exchange(lk, std::unique_lock{waiter.mutex});
return waiter.Wait(lk, timeout);
} }
bool Signal(s32 signal_count) { bool Signal(s32 signal_count) {
@ -48,25 +53,47 @@ public:
token_count += signal_count; token_count += signal_count;
// Wake up threads in order of priority. // Wake up threads in order of priority.
for (auto& waiter : wait_list) { for (auto it = wait_list.begin(); it != wait_list.end();) {
auto& waiter = *it;
if (waiter.need_count > token_count) { if (waiter.need_count > token_count) {
it++;
continue; continue;
} }
std::scoped_lock lk2{waiter.mutex};
token_count -= waiter.need_count; token_count -= waiter.need_count;
waiter.cv.notify_one(); waiter.cv.notify_one();
it = wait_list.erase(it);
} }
return true; return true;
} }
private: int Cancel(s32 set_count, s32* num_waiters) {
std::scoped_lock lk{mutex};
if (num_waiters) {
*num_waiters = wait_list.size();
}
for (auto& waiter : wait_list) {
waiter.was_cancled = true;
waiter.cv.notify_one();
}
wait_list.clear();
token_count = set_count < 0 ? init_count : set_count;
return ORBIS_OK;
}
public:
struct WaitingThread : public ListBaseHook { struct WaitingThread : public ListBaseHook {
std::mutex mutex; std::mutex mutex;
std::string name;
std::condition_variable cv; std::condition_variable cv;
u32 priority; u32 priority;
s32 need_count; s32 need_count;
bool was_deleted{};
bool was_cancled{};
explicit WaitingThread(s32 need_count, bool is_fifo) : need_count{need_count} { explicit WaitingThread(s32 need_count, bool is_fifo) : need_count{need_count} {
name = scePthreadSelf()->name;
if (is_fifo) { if (is_fifo) {
return; return;
} }
@ -77,12 +104,24 @@ private:
priority = param.sched_priority; priority = param.sched_priority;
} }
bool Wait(u64* timeout) { int GetResult(bool timed_out) {
std::unique_lock lk{mutex}; if (timed_out) {
return SCE_KERNEL_ERROR_ETIMEDOUT;
}
if (was_deleted) {
return SCE_KERNEL_ERROR_EACCES;
}
if (was_cancled) {
return SCE_KERNEL_ERROR_ECANCELED;
}
return SCE_OK;
}
int Wait(std::unique_lock<std::mutex>& lk, u32* timeout) {
if (!timeout) { if (!timeout) {
// Wait indefinitely until we are woken up. // Wait indefinitely until we are woken up.
cv.wait(lk); cv.wait(lk);
return true; return GetResult(false);
} }
// Wait until timeout runs out, recording how much remaining time there was. // Wait until timeout runs out, recording how much remaining time there was.
const auto start = std::chrono::high_resolution_clock::now(); const auto start = std::chrono::high_resolution_clock::now();
@ -91,16 +130,11 @@ private:
const auto time = const auto time =
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count(); std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
*timeout -= time; *timeout -= time;
return status != std::cv_status::timeout; return GetResult(status == std::cv_status::timeout);
}
bool operator<(const WaitingThread& other) const {
return priority < other.priority;
} }
}; };
void AddWaiter(WaitingThread& waiter) { void AddWaiter(WaitingThread& waiter) {
std::scoped_lock lk{mutex};
// Insert at the end of the list for FIFO order. // Insert at the end of the list for FIFO order.
if (is_fifo) { if (is_fifo) {
wait_list.push_back(waiter); wait_list.push_back(waiter);
@ -114,20 +148,6 @@ private:
wait_list.insert(it, waiter); wait_list.insert(it, waiter);
} }
void PopWaiter(WaitingThread& waiter) {
std::scoped_lock lk{mutex};
wait_list.erase(WaitingThreads::s_iterator_to(waiter));
}
bool HasAvailableTokens(s32 need_count) {
std::scoped_lock lk{mutex};
if (token_count >= need_count) {
token_count -= need_count;
return true;
}
return false;
}
using WaitingThreads = using WaitingThreads =
boost::intrusive::list<WaitingThread, boost::intrusive::base_hook<ListBaseHook>, boost::intrusive::list<WaitingThread, boost::intrusive::base_hook<ListBaseHook>,
boost::intrusive::constant_time_size<false>>; boost::intrusive::constant_time_size<false>>;
@ -136,6 +156,7 @@ private:
std::atomic<s32> token_count; std::atomic<s32> token_count;
std::mutex mutex; std::mutex mutex;
s32 max_count; s32 max_count;
s32 init_count;
bool is_fifo; bool is_fifo;
}; };
@ -151,9 +172,8 @@ s32 PS4_SYSV_ABI sceKernelCreateSema(OrbisKernelSema* sem, const char* pName, u3
return ORBIS_OK; return ORBIS_OK;
} }
s32 PS4_SYSV_ABI sceKernelWaitSema(OrbisKernelSema sem, s32 needCount, u64* pTimeout) { s32 PS4_SYSV_ABI sceKernelWaitSema(OrbisKernelSema sem, s32 needCount, u32* pTimeout) {
ASSERT(sem->Wait(true, needCount, pTimeout)); return sem->Wait(true, needCount, pTimeout);
return ORBIS_OK;
} }
s32 PS4_SYSV_ABI sceKernelSignalSema(OrbisKernelSema sem, s32 signalCount) { s32 PS4_SYSV_ABI sceKernelSignalSema(OrbisKernelSema sem, s32 signalCount) {
@ -164,9 +184,18 @@ s32 PS4_SYSV_ABI sceKernelSignalSema(OrbisKernelSema sem, s32 signalCount) {
} }
s32 PS4_SYSV_ABI sceKernelPollSema(OrbisKernelSema sem, s32 needCount) { s32 PS4_SYSV_ABI sceKernelPollSema(OrbisKernelSema sem, s32 needCount) {
if (!sem->Wait(false, needCount, nullptr)) { return sem->Wait(false, needCount, nullptr);
return ORBIS_KERNEL_ERROR_EBUSY; }
int PS4_SYSV_ABI sceKernelCancelSema(OrbisKernelSema sem, s32 setCount, s32 *pNumWaitThreads) {
return sem->Cancel(setCount, pNumWaitThreads);
}
int PS4_SYSV_ABI sceKernelDeleteSema(OrbisKernelSema sem) {
if (!sem) {
return SCE_KERNEL_ERROR_ESRCH;
} }
delete sem;
return ORBIS_OK; return ORBIS_OK;
} }
@ -175,6 +204,8 @@ void SemaphoreSymbolsRegister(Core::Loader::SymbolsResolver* sym) {
LIB_FUNCTION("Zxa0VhQVTsk", "libkernel", 1, "libkernel", 1, 1, sceKernelWaitSema); LIB_FUNCTION("Zxa0VhQVTsk", "libkernel", 1, "libkernel", 1, 1, sceKernelWaitSema);
LIB_FUNCTION("4czppHBiriw", "libkernel", 1, "libkernel", 1, 1, sceKernelSignalSema); LIB_FUNCTION("4czppHBiriw", "libkernel", 1, "libkernel", 1, 1, sceKernelSignalSema);
LIB_FUNCTION("12wOHk8ywb0", "libkernel", 1, "libkernel", 1, 1, sceKernelPollSema); LIB_FUNCTION("12wOHk8ywb0", "libkernel", 1, "libkernel", 1, 1, sceKernelPollSema);
LIB_FUNCTION("4DM06U2BNEY", "libkernel", 1, "libkernel", 1, 1, sceKernelCancelSema);
LIB_FUNCTION("R1Jvn8bSCW8", "libkernel", 1, "libkernel", 1, 1, sceKernelDeleteSema);
} }
} // namespace Libraries::Kernel } // namespace Libraries::Kernel