Kernel: Object ShouldWait and Acquire calls now take a thread as a parameter.

This will be useful when implementing mutex priority inheritance.
This commit is contained in:
Subv 2017-01-01 16:53:22 -05:00
commit 90570c153b
17 changed files with 56 additions and 68 deletions

View file

@ -30,12 +30,12 @@ SharedPtr<Event> Event::Create(ResetType reset_type, std::string name) {
return evt; return evt;
} }
bool Event::ShouldWait() { bool Event::ShouldWait(Thread* thread) const {
return !signaled; return !signaled;
} }
void Event::Acquire() { void Event::Acquire(Thread* thread) {
ASSERT_MSG(!ShouldWait(), "object unavailable!"); ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
// Release the event if it's not sticky... // Release the event if it's not sticky...
if (reset_type != ResetType::Sticky) if (reset_type != ResetType::Sticky)

View file

@ -35,8 +35,8 @@ public:
bool signaled; ///< Whether the event has already been signaled bool signaled; ///< Whether the event has already been signaled
std::string name; ///< Name of event (optional) std::string name; ///< Name of event (optional)
bool ShouldWait() override; bool ShouldWait(Thread* thread) const override;
void Acquire() override; void Acquire(Thread* thread) override;
void Signal(); void Signal();
void Clear(); void Clear();

View file

@ -39,11 +39,6 @@ SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() {
thread->status == THREADSTATUS_DEAD; thread->status == THREADSTATUS_DEAD;
}); });
// TODO(Subv): This call should be performed inside the loop below to check if an object can be
// acquired by a particular thread. This is useful for things like recursive locking of Mutexes.
if (ShouldWait())
return nullptr;
Thread* candidate = nullptr; Thread* candidate = nullptr;
s32 candidate_priority = THREADPRIO_LOWEST + 1; s32 candidate_priority = THREADPRIO_LOWEST + 1;
@ -51,9 +46,12 @@ SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() {
if (thread->current_priority >= candidate_priority) if (thread->current_priority >= candidate_priority)
continue; continue;
if (ShouldWait(thread.get()))
continue;
bool ready_to_run = bool ready_to_run =
std::none_of(thread->wait_objects.begin(), thread->wait_objects.end(), std::none_of(thread->wait_objects.begin(), thread->wait_objects.end(),
[](const SharedPtr<WaitObject>& object) { return object->ShouldWait(); }); [&thread](const SharedPtr<WaitObject>& object) { return object->ShouldWait(thread.get()); });
if (ready_to_run) { if (ready_to_run) {
candidate = thread.get(); candidate = thread.get();
candidate_priority = thread->current_priority; candidate_priority = thread->current_priority;
@ -66,7 +64,7 @@ SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() {
void WaitObject::WakeupAllWaitingThreads() { void WaitObject::WakeupAllWaitingThreads() {
while (auto thread = GetHighestPriorityReadyThread()) { while (auto thread = GetHighestPriorityReadyThread()) {
if (!thread->IsSleepingOnWaitAll()) { if (!thread->IsSleepingOnWaitAll()) {
Acquire(); Acquire(thread.get());
// Set the output index of the WaitSynchronizationN call to the index of this object. // Set the output index of the WaitSynchronizationN call to the index of this object.
if (thread->wait_set_output) { if (thread->wait_set_output) {
thread->SetWaitSynchronizationOutput(thread->GetWaitObjectIndex(this)); thread->SetWaitSynchronizationOutput(thread->GetWaitObjectIndex(this));
@ -74,7 +72,7 @@ void WaitObject::WakeupAllWaitingThreads() {
} }
} else { } else {
for (auto& object : thread->wait_objects) { for (auto& object : thread->wait_objects) {
object->Acquire(); object->Acquire(thread.get());
object->RemoveWaitingThread(thread.get()); object->RemoveWaitingThread(thread.get());
} }
// Note: This case doesn't update the output index of WaitSynchronizationN. // Note: This case doesn't update the output index of WaitSynchronizationN.

View file

@ -132,13 +132,14 @@ using SharedPtr = boost::intrusive_ptr<T>;
class WaitObject : public Object { class WaitObject : public Object {
public: public:
/** /**
* Check if the current thread should wait until the object is available * Check if the specified thread should wait until the object is available
* @param thread The thread about which we're deciding.
* @return True if the current thread should wait due to this object being unavailable * @return True if the current thread should wait due to this object being unavailable
*/ */
virtual bool ShouldWait() = 0; virtual bool ShouldWait(Thread* thread) const = 0;
/// Acquire/lock the object if it is available /// Acquire/lock the object for the specified thread if it is available
virtual void Acquire() = 0; virtual void Acquire(Thread* thread) = 0;
/** /**
* Add a thread to wait on this object * Add a thread to wait on this object

View file

@ -40,31 +40,19 @@ SharedPtr<Mutex> Mutex::Create(bool initial_locked, std::string name) {
mutex->name = std::move(name); mutex->name = std::move(name);
mutex->holding_thread = nullptr; mutex->holding_thread = nullptr;
// Acquire mutex with current thread if initialized as locked... // Acquire mutex with current thread if initialized as locked
if (initial_locked) if (initial_locked)
mutex->Acquire(); mutex->Acquire(GetCurrentThread());
return mutex; return mutex;
} }
bool Mutex::ShouldWait() { bool Mutex::ShouldWait(Thread* thread) const {
auto thread = GetCurrentThread(); return lock_count > 0 && thread != holding_thread;
bool wait = lock_count > 0 && holding_thread != thread;
// If the holding thread of the mutex is lower priority than this thread, that thread should
// temporarily inherit this thread's priority
if (wait && thread->current_priority < holding_thread->current_priority)
holding_thread->BoostPriority(thread->current_priority);
return wait;
} }
void Mutex::Acquire() { void Mutex::Acquire(Thread* thread) {
Acquire(GetCurrentThread()); ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
}
void Mutex::Acquire(SharedPtr<Thread> thread) {
ASSERT_MSG(!ShouldWait(), "object unavailable!");
// Actually "acquire" the mutex only if we don't already have it... // Actually "acquire" the mutex only if we don't already have it...
if (lock_count == 0) { if (lock_count == 0) {

View file

@ -38,8 +38,9 @@ public:
std::string name; ///< Name of mutex (optional) std::string name; ///< Name of mutex (optional)
SharedPtr<Thread> holding_thread; ///< Thread that has acquired the mutex SharedPtr<Thread> holding_thread; ///< Thread that has acquired the mutex
bool ShouldWait() override; bool ShouldWait(Thread* thread) const override;
void Acquire() override; void Acquire(Thread* thread) override;
/** /**
* Acquires the specified mutex for the specified thread * Acquires the specified mutex for the specified thread

View file

@ -30,12 +30,12 @@ ResultVal<SharedPtr<Semaphore>> Semaphore::Create(s32 initial_count, s32 max_cou
return MakeResult<SharedPtr<Semaphore>>(std::move(semaphore)); return MakeResult<SharedPtr<Semaphore>>(std::move(semaphore));
} }
bool Semaphore::ShouldWait() { bool Semaphore::ShouldWait(Thread* thread) const {
return available_count <= 0; return available_count <= 0;
} }
void Semaphore::Acquire() { void Semaphore::Acquire(Thread* thread) {
ASSERT_MSG(!ShouldWait(), "object unavailable!"); ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
--available_count; --available_count;
} }

View file

@ -39,8 +39,8 @@ public:
s32 available_count; ///< Number of free slots left in the semaphore s32 available_count; ///< Number of free slots left in the semaphore
std::string name; ///< Name of semaphore (optional) std::string name; ///< Name of semaphore (optional)
bool ShouldWait() override; bool ShouldWait(Thread* thread) const override;
void Acquire() override; void Acquire(Thread* thread) override;
/** /**
* Releases a certain number of slots from a semaphore. * Releases a certain number of slots from a semaphore.

View file

@ -14,13 +14,13 @@ namespace Kernel {
ServerPort::ServerPort() {} ServerPort::ServerPort() {}
ServerPort::~ServerPort() {} ServerPort::~ServerPort() {}
bool ServerPort::ShouldWait() { bool ServerPort::ShouldWait(Thread* thread) const {
// If there are no pending sessions, we wait until a new one is added. // If there are no pending sessions, we wait until a new one is added.
return pending_sessions.size() == 0; return pending_sessions.size() == 0;
} }
void ServerPort::Acquire() { void ServerPort::Acquire(Thread* thread) {
ASSERT_MSG(!ShouldWait(), "object unavailable!"); ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
} }
std::tuple<SharedPtr<ServerPort>, SharedPtr<ClientPort>> ServerPort::CreatePortPair( std::tuple<SharedPtr<ServerPort>, SharedPtr<ClientPort>> ServerPort::CreatePortPair(

View file

@ -53,8 +53,8 @@ public:
/// ServerSessions created from this port inherit a reference to this handler. /// ServerSessions created from this port inherit a reference to this handler.
std::shared_ptr<Service::SessionRequestHandler> hle_handler; std::shared_ptr<Service::SessionRequestHandler> hle_handler;
bool ShouldWait() override; bool ShouldWait(Thread* thread) const override;
void Acquire() override; void Acquire(Thread* thread) override;
private: private:
ServerPort(); ServerPort();

View file

@ -29,12 +29,12 @@ ResultVal<SharedPtr<ServerSession>> ServerSession::Create(
return MakeResult<SharedPtr<ServerSession>>(std::move(server_session)); return MakeResult<SharedPtr<ServerSession>>(std::move(server_session));
} }
bool ServerSession::ShouldWait() { bool ServerSession::ShouldWait(Thread* thread) const {
return !signaled; return !signaled;
} }
void ServerSession::Acquire() { void ServerSession::Acquire(Thread* thread) {
ASSERT_MSG(!ShouldWait(), "object unavailable!"); ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
signaled = false; signaled = false;
} }

View file

@ -57,9 +57,9 @@ public:
*/ */
ResultCode HandleSyncRequest(); ResultCode HandleSyncRequest();
bool ShouldWait() override; bool ShouldWait(Thread* thread) const override;
void Acquire() override; void Acquire(Thread* thread) override;
std::string name; ///< The name of this session (optional) std::string name; ///< The name of this session (optional)
bool signaled; ///< Whether there's new data available to this ServerSession bool signaled; ///< Whether there's new data available to this ServerSession

View file

@ -27,12 +27,12 @@ namespace Kernel {
/// Event type for the thread wake up event /// Event type for the thread wake up event
static int ThreadWakeupEventType; static int ThreadWakeupEventType;
bool Thread::ShouldWait() { bool Thread::ShouldWait(Thread* thread) const {
return status != THREADSTATUS_DEAD; return status != THREADSTATUS_DEAD;
} }
void Thread::Acquire() { void Thread::Acquire(Thread* thread) {
ASSERT_MSG(!ShouldWait(), "object unavailable!"); ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
} }
// TODO(yuriks): This can be removed if Thread objects are explicitly pooled in the future, allowing // TODO(yuriks): This can be removed if Thread objects are explicitly pooled in the future, allowing

View file

@ -72,8 +72,8 @@ public:
return HANDLE_TYPE; return HANDLE_TYPE;
} }
bool ShouldWait() override; bool ShouldWait(Thread* thread) const override;
void Acquire() override; void Acquire(Thread* thread) override;
/** /**
* Gets the thread's current priority * Gets the thread's current priority

View file

@ -39,12 +39,12 @@ SharedPtr<Timer> Timer::Create(ResetType reset_type, std::string name) {
return timer; return timer;
} }
bool Timer::ShouldWait() { bool Timer::ShouldWait(Thread* thread) const {
return !signaled; return !signaled;
} }
void Timer::Acquire() { void Timer::Acquire(Thread* thread) {
ASSERT_MSG(!ShouldWait(), "object unavailable!"); ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
if (reset_type == ResetType::OneShot) if (reset_type == ResetType::OneShot)
signaled = false; signaled = false;

View file

@ -39,8 +39,8 @@ public:
u64 initial_delay; ///< The delay until the timer fires for the first time u64 initial_delay; ///< The delay until the timer fires for the first time
u64 interval_delay; ///< The delay until the timer fires after the first time u64 interval_delay; ///< The delay until the timer fires after the first time
bool ShouldWait() override; bool ShouldWait(Thread* thread) const override;
void Acquire() override; void Acquire(Thread* thread) override;
/** /**
* Starts the timer, with the specified initial delay and interval. * Starts the timer, with the specified initial delay and interval.

View file

@ -272,7 +272,7 @@ static ResultCode WaitSynchronization1(Kernel::Handle handle, s64 nano_seconds)
LOG_TRACE(Kernel_SVC, "called handle=0x%08X(%s:%s), nanoseconds=%lld", handle, LOG_TRACE(Kernel_SVC, "called handle=0x%08X(%s:%s), nanoseconds=%lld", handle,
object->GetTypeName().c_str(), object->GetName().c_str(), nano_seconds); object->GetTypeName().c_str(), object->GetName().c_str(), nano_seconds);
if (object->ShouldWait()) { if (object->ShouldWait(thread)) {
if (nano_seconds == 0) if (nano_seconds == 0)
return ERR_SYNC_TIMEOUT; return ERR_SYNC_TIMEOUT;
@ -294,7 +294,7 @@ static ResultCode WaitSynchronization1(Kernel::Handle handle, s64 nano_seconds)
return ERR_SYNC_TIMEOUT; return ERR_SYNC_TIMEOUT;
} }
object->Acquire(); object->Acquire(thread);
return RESULT_SUCCESS; return RESULT_SUCCESS;
} }
@ -336,11 +336,11 @@ static ResultCode WaitSynchronizationN(s32* out, Kernel::Handle* handles, s32 ha
if (wait_all) { if (wait_all) {
bool all_available = bool all_available =
std::all_of(objects.begin(), objects.end(), std::all_of(objects.begin(), objects.end(),
[](const ObjectPtr& object) { return !object->ShouldWait(); }); [thread](const ObjectPtr& object) { return !object->ShouldWait(thread); });
if (all_available) { if (all_available) {
// We can acquire all objects right now, do so. // We can acquire all objects right now, do so.
for (auto& object : objects) for (auto& object : objects)
object->Acquire(); object->Acquire(thread);
// Note: In this case, the `out` parameter is not set, // Note: In this case, the `out` parameter is not set,
// and retains whatever value it had before. // and retains whatever value it had before.
return RESULT_SUCCESS; return RESULT_SUCCESS;
@ -380,12 +380,12 @@ static ResultCode WaitSynchronizationN(s32* out, Kernel::Handle* handles, s32 ha
} else { } else {
// Find the first object that is acquirable in the provided list of objects // Find the first object that is acquirable in the provided list of objects
auto itr = std::find_if(objects.begin(), objects.end(), auto itr = std::find_if(objects.begin(), objects.end(),
[](const ObjectPtr& object) { return !object->ShouldWait(); }); [thread](const ObjectPtr& object) { return !object->ShouldWait(thread); });
if (itr != objects.end()) { if (itr != objects.end()) {
// We found a ready object, acquire it and set the result value // We found a ready object, acquire it and set the result value
Kernel::WaitObject* object = itr->get(); Kernel::WaitObject* object = itr->get();
object->Acquire(); object->Acquire(thread);
*out = std::distance(objects.begin(), itr); *out = std::distance(objects.begin(), itr);
return RESULT_SUCCESS; return RESULT_SUCCESS;
} }