#pragma once
#include <vector>
#include <map>
#include <algorithm>
#include "Common/CommonTypes.h"
#include "Core/CoreTiming.h"
#include "Core/HLE/sceKernelThread.h"
#include "Core/HLE/ErrorCodes.h"
namespace HLEKernel
{
// Handles expiry of a timed wait for `threadID`, which is blocked with wait
// type `waitType` on a kernel object of type KO.
// Does nothing unless the thread's current wait id is non-zero and resolves to
// a live KO. Otherwise it:
//   - writes 0 to the guest timeout address, when the wait had one,
//   - resumes the thread with SCE_KERNEL_ERROR_WAIT_TIMEOUT,
//   - requests a reschedule.
// The thread's entry in the object's waitingThreads is NOT touched here
// (presumably pruned later, e.g. by CleanupWaitingThreads — confirm at callers).
template <typename KO, WaitType waitType>
inline void WaitExecTimeout(SceUID threadID) {
	u32 error;
	const SceUID waitID = __KernelGetWaitID(threadID, waitType, error);
	const u32 timeoutAddr = __KernelGetWaitTimeoutPtr(threadID, error);

	// A zero wait id means there is no object to look up.
	if (waitID == 0)
		return;
	KO *object = kernelObjects.Get<KO>(waitID, error);
	if (!object)
		return;

	if (timeoutAddr != 0)
		Memory::Write_U32(0, timeoutAddr);
	__KernelResumeThreadFromWait(threadID, SCE_KERNEL_ERROR_WAIT_TIMEOUT);
	__KernelReSchedule("wait timed out");
}
// Moves the first waitingThreads entry belonging to `threadID` into
// pausedWaits[pauseKey], stamping it with `pauseTimeout`.
// The stamp is the absolute deadline, or 0 when there is none.
// Returns false, and modifies nothing, when no entry for `threadID` exists.
// The search uses an explicit "not found" check rather than a zero-initialised
// sentinel. With a sentinel, a threadID of 0 would be reported as found and a
// bogus zeroed record would be stored.
template <typename WaitInfoType, typename PauseType>
inline bool WaitPauseHelperUpdate(SceUID pauseKey, SceUID threadID, std::vector<WaitInfoType> &waitingThreads, std::map<SceUID, PauseType> &pausedWaits, u64 pauseTimeout) {
	typename std::vector<WaitInfoType>::iterator match = std::find_if(waitingThreads.begin(), waitingThreads.end(),
		[threadID](const WaitInfoType &info) { return info.threadID == threadID; });
	if (match == waitingThreads.end())
		return false;

	// Copy before erasing: erase invalidates `match`.
	WaitInfoType waitData = *match;
	waitingThreads.erase(match);

	waitData.pausedTimeout = pauseTimeout;
	pausedWaits[pauseKey] = waitData;
	return true;
}
// Specialisation for wait lists that hold bare thread ids.
// Every occurrence of `threadID` is dropped from waitingThreads, and only the
// deadline is stored under pauseKey, because a bare id has no pausedTimeout
// field. Always reports success, even when the id was not in the list.
template <>
inline bool WaitPauseHelperUpdate<SceUID, u64>(SceUID pauseKey, SceUID threadID, std::vector<SceUID> &waitingThreads, std::map<SceUID, u64> &pausedWaits, u64 pauseTimeout) {
	std::vector<SceUID>::iterator keptEnd = std::remove(waitingThreads.begin(), waitingThreads.end(), threadID);
	waitingThreads.erase(keptEnd, waitingThreads.end());

	pausedWaits[pauseKey] = pauseTimeout;
	return true;
}
// Takes the paused record stored under `pauseKey` out of pausedWaits and copies
// it into `waitData`. Returns the record's pausedTimeout (absolute deadline).
// If no record exists, `waitData` receives a value-initialised PauseType and the
// map is left unchanged. Callers are expected to check existence first.
// `threadID` is unused here; it is only needed by the SceUID specialisation.
template <typename WaitInfoType, typename PauseType>
inline u64 WaitPauseHelperGet(SceUID pauseKey, SceUID threadID, std::map<SceUID, PauseType> &pausedWaits, WaitInfoType &waitData) {
	typename std::map<SceUID, PauseType>::iterator entry = pausedWaits.find(pauseKey);
	if (entry == pausedWaits.end()) {
		waitData = PauseType();
	} else {
		waitData = entry->second;
		pausedWaits.erase(entry);
	}
	return waitData.pausedTimeout;
}
// Specialisation for wait lists holding bare thread ids.
// The restored "record" is simply `threadID`. The deadline stored under
// `pauseKey` is removed from the map and returned. A missing key yields 0 and
// leaves the map unchanged.
template <>
inline u64 WaitPauseHelperGet<SceUID, u64>(SceUID pauseKey, SceUID threadID, std::map<SceUID, u64> &pausedWaits, SceUID &waitData) {
	waitData = threadID;

	u64 deadline = 0;
	std::map<SceUID, u64>::iterator entry = pausedWaits.find(pauseKey);
	if (entry != pausedWaits.end()) {
		deadline = entry->second;
		pausedWaits.erase(entry);
	}
	return deadline;
}
// Outcome codes of WaitBeginCallback / WaitEndCallback.
// Negative values signal a failure to find wait state.
// Zero and positive values are normal outcomes.
enum WaitBeginEndCallbackResult {
	// The thread's wait record could not be found when pausing the wait.
	WAIT_CB_BAD_WAIT_DATA = -2,
	// The object the thread is waiting on could not be resolved.
	WAIT_CB_BAD_WAIT_ID = -1,
	// Work finished, with no resumed wait to re-queue.
	// Covers: already paused, lock acquired, or wait reported as deleted.
	WAIT_CB_SUCCESS = 0,
	// The thread is still blocked and its wait has been re-armed.
	WAIT_CB_RESUMED_WAIT = 1,
	// The deadline passed while paused; the thread was resumed with a timeout.
	WAIT_CB_TIMED_OUT = 2
};
// Pauses `threadID`'s wait by moving its record from waitingThreads into
// pausedWaits. Presumably called when a callback starts running on a waiting
// thread.
//   - Pause key: `prevCallbackId` when it is non-zero, otherwise `threadID`.
//   - If a record already exists under that key, nothing is changed.
//   - With `doTimeout` set and a timer id other than -1, the pending timeout
//     event is unscheduled and its absolute deadline is saved with the record.
//     Otherwise the saved deadline is 0.
// Returns WAIT_CB_BAD_WAIT_DATA when the thread's record could not be stored,
// WAIT_CB_SUCCESS otherwise.
template <typename WaitInfoType, typename PauseType>
WaitBeginEndCallbackResult WaitBeginCallback(SceUID threadID, SceUID prevCallbackId, int waitTimer, std::vector<WaitInfoType> &waitingThreads, std::map<SceUID, PauseType> &pausedWaits, bool doTimeout = true) {
	const SceUID pauseKey = (prevCallbackId != 0) ? prevCallbackId : threadID;

	if (pausedWaits.count(pauseKey) != 0)
		return WAIT_CB_SUCCESS;

	u64 pausedTimeout = 0;
	const bool hasTimer = doTimeout && waitTimer != -1;
	if (hasTimer) {
		// Unschedule first, then read the clock, to turn "cycles left" into an absolute tick.
		const s64 remaining = CoreTiming::UnscheduleEvent(waitTimer, threadID);
		pausedTimeout = CoreTiming::GetTicks() + remaining;
	}

	const bool stored = WaitPauseHelperUpdate(pauseKey, threadID, waitingThreads, pausedWaits, pausedTimeout);
	return stored ? WAIT_CB_SUCCESS : WAIT_CB_BAD_WAIT_DATA;
}
// Convenience wrapper: resolves the KO that `threadID` is waiting on (via
// `waitType`), then pauses the wait using that object's waitingThreads and
// pausedWaits. The timeout is only captured when the wait has a non-zero guest
// timeout address.
// Returns WAIT_CB_BAD_WAIT_ID if the wait id is 0 or does not resolve to a KO.
template <typename KO, WaitType waitType, typename WaitInfoType>
WaitBeginEndCallbackResult WaitBeginCallback(SceUID threadID, SceUID prevCallbackId, int waitTimer) {
	u32 error;
	const SceUID waitID = __KernelGetWaitID(threadID, waitType, error);
	const u32 timeoutAddr = __KernelGetWaitTimeoutPtr(threadID, error);

	KO *object = NULL;
	if (waitID != 0)
		object = kernelObjects.Get<KO>(waitID, error);
	if (!object)
		return WAIT_CB_BAD_WAIT_ID;

	const bool hasTimeout = timeoutAddr != 0;
	return WaitBeginCallback(threadID, prevCallbackId, waitTimer, object->waitingThreads, object->pausedWaits, hasTimeout);
}
// Resumes a wait previously paused by WaitBeginCallback.
//   - The pause key follows the same rule: `prevCallbackId` if non-zero, else
//     `threadID`.
//   - If the object is gone, or no paused record exists, the thread is resumed
//     with SCE_KERNEL_ERROR_WAIT_DELETE.
//   - Otherwise the record is restored into `waitData` and `TryUnlock` is
//     attempted.
//   - If TryUnlock fails, the saved deadline (0 = none) is checked. A passed
//     deadline resumes the thread with SCE_KERNEL_ERROR_WAIT_TIMEOUT. Otherwise
//     the timeout event is re-armed with the remaining cycles.
// On WAIT_CB_RESUMED_WAIT the caller must put `waitData` back into the wait list.
template <typename KO, WaitType waitType, typename WaitInfoType, typename PauseType, class TryUnlockFunc>
WaitBeginEndCallbackResult WaitEndCallback(SceUID threadID, SceUID prevCallbackId, int waitTimer, TryUnlockFunc TryUnlock, WaitInfoType &waitData, std::vector<WaitInfoType> &waitingThreads, std::map<SceUID, PauseType> &pausedWaits) {
	const SceUID pauseKey = (prevCallbackId != 0) ? prevCallbackId : threadID;

	u32 error;
	const SceUID waitID = __KernelGetWaitID(threadID, waitType, error);
	const u32 timeoutAddr = __KernelGetWaitTimeoutPtr(threadID, error);
	KO *object = NULL;
	if (waitID != 0)
		object = kernelObjects.Get<KO>(waitID, error);

	// The guest timeout is only relevant when there is both an address and a timer event.
	const bool timed = timeoutAddr != 0 && waitTimer != -1;

	if (object == NULL || pausedWaits.count(pauseKey) == 0) {
		if (timed)
			Memory::Write_U32(0, timeoutAddr);
		__KernelResumeThreadFromWait(threadID, SCE_KERNEL_ERROR_WAIT_DELETE);
		return WAIT_CB_SUCCESS;
	}

	const u64 waitDeadline = WaitPauseHelperGet(pauseKey, threadID, pausedWaits, waitData);

	bool wokeThreads;
	if (TryUnlock(object, waitData, error, 0, wokeThreads))
		return WAIT_CB_SUCCESS;

	// The deadline only matters if the unlock attempt failed.
	const s64 cyclesLeft = waitDeadline - CoreTiming::GetTicks();
	const bool expired = waitDeadline != 0 && cyclesLeft < 0;
	if (!expired) {
		if (timed)
			CoreTiming::ScheduleEvent(cyclesLeft, waitTimer, __KernelGetCurThread());
		return WAIT_CB_RESUMED_WAIT;
	}

	if (timed)
		Memory::Write_U32(0, timeoutAddr);
	__KernelResumeThreadFromWait(threadID, SCE_KERNEL_ERROR_WAIT_TIMEOUT);
	return WAIT_CB_TIMED_OUT;
}
// Convenience wrapper: resolves the KO that `threadID` is waiting on and runs
// the full WaitEndCallback against that object's waitingThreads and pausedWaits.
//   - If the object cannot be resolved, the thread is resumed with
//     SCE_KERNEL_ERROR_WAIT_DELETE (zeroing the guest timeout when timed).
//   - If the wait is re-armed (WAIT_CB_RESUMED_WAIT), the restored record is
//     appended back to the object's waitingThreads.
template <typename KO, WaitType waitType, typename WaitInfoType, class TryUnlockFunc>
WaitBeginEndCallbackResult WaitEndCallback(SceUID threadID, SceUID prevCallbackId, int waitTimer, TryUnlockFunc TryUnlock) {
	u32 error;
	const SceUID waitID = __KernelGetWaitID(threadID, waitType, error);
	const u32 timeoutAddr = __KernelGetWaitTimeoutPtr(threadID, error);

	KO *object = NULL;
	if (waitID != 0)
		object = kernelObjects.Get<KO>(waitID, error);

	if (object == NULL) {
		if (timeoutAddr != 0 && waitTimer != -1)
			Memory::Write_U32(0, timeoutAddr);
		__KernelResumeThreadFromWait(threadID, SCE_KERNEL_ERROR_WAIT_DELETE);
		return WAIT_CB_SUCCESS;
	}

	WaitInfoType waitData;
	const WaitBeginEndCallbackResult outcome = WaitEndCallback<KO, waitType>(threadID, prevCallbackId, waitTimer, TryUnlock, waitData, object->waitingThreads, object->pausedWaits);
	if (outcome == WAIT_CB_RESUMED_WAIT)
		object->waitingThreads.push_back(waitData);
	return outcome;
}
template <typename T>
inline bool VerifyWait(const T &waitInfo, WaitType waitType, SceUID uid) {
u32 error;
SceUID waitID = __KernelGetWaitID(waitInfo.threadID, waitType, error);
return waitID == uid && error == 0;
}
template <>
inline bool VerifyWait(const SceUID &threadID, WaitType waitType, SceUID uid) {
u32 error;
SceUID waitID = __KernelGetWaitID(threadID, waitType, error);
return waitID == uid && error == 0;
}
// Resumes `threadID` with `result`, but only if it is still waiting with
// `waitType` on object `uid`. Returns whether the thread was resumed.
template <typename T>
inline bool ResumeFromWait(SceUID threadID, WaitType waitType, SceUID uid, T result) {
	const bool stillWaiting = VerifyWait(threadID, waitType, uid);
	if (!stillWaiting)
		return false;

	__KernelResumeThreadFromWait(threadID, result);
	return true;
}
// Removes every entry of `waitingThreads` whose thread is no longer waiting
// with `waitType` on object `uid`, as judged by VerifyWait.
// The surviving entries keep their relative order. Callers that wake threads in
// list order (presumably the FIFO-ordered waits — confirm at call sites) are
// therefore not reshuffled by the cleanup. The previous swap-with-last removal
// moved later waiters ahead of earlier ones.
// VerifyWait is evaluated exactly once per entry.
template <typename T>
inline void CleanupWaitingThreads(WaitType waitType, SceUID uid, std::vector<T> &waitingThreads) {
	typename std::vector<T>::iterator keptEnd = std::remove_if(waitingThreads.begin(), waitingThreads.end(),
		[waitType, uid](const T &waitInfo) { return !VerifyWait(waitInfo, waitType, uid); });
	waitingThreads.erase(keptEnd, waitingThreads.end());
}
// Erases every element of `waitingThreads` that compares equal to `threadID`,
// keeping the relative order of the rest.
// T must be comparable with SceUID through operator==.
template <typename T>
inline void RemoveWaitingThread(std::vector<T> &waitingThreads, const SceUID threadID) {
	typename std::vector<T>::iterator keptEnd = std::remove(waitingThreads.begin(), waitingThreads.end(), threadID);
	waitingThreads.erase(keptEnd, waitingThreads.end());
}
};