Tests: Add tests for fibers and refactor/fix Fiber class
This commit is contained in:
		
							parent
							
								
									353166d648
								
							
						
					
					
						commit
						279ff1c0ff
					
				
					 4 changed files with 247 additions and 19 deletions
				
			
		|  | @ -3,18 +3,21 @@ | |||
| // Refer to the license.txt file included.
 | ||||
| 
 | ||||
| #include "common/fiber.h" | ||||
| #ifdef _MSC_VER | ||||
| #include <windows.h> | ||||
| #else | ||||
| #include <boost/context/detail/fcontext.hpp> | ||||
| #endif | ||||
| 
 | ||||
| namespace Common { | ||||
| 
 | ||||
| #ifdef _MSC_VER | ||||
| #include <windows.h> | ||||
| 
 | ||||
| struct Fiber::FiberImpl { | ||||
|     LPVOID handle = nullptr; | ||||
| }; | ||||
| 
 | ||||
| void Fiber::_start([[maybe_unused]] void* parameter) { | ||||
|     guard.lock(); | ||||
| void Fiber::start() { | ||||
|     if (previous_fiber) { | ||||
|         previous_fiber->guard.unlock(); | ||||
|         previous_fiber = nullptr; | ||||
|  | @ -22,10 +25,10 @@ void Fiber::_start([[maybe_unused]] void* parameter) { | |||
|     entry_point(start_parameter); | ||||
| } | ||||
| 
 | ||||
| static void __stdcall FiberStartFunc(LPVOID lpFiberParameter) | ||||
| void __stdcall Fiber::FiberStartFunc(void* fiber_parameter) | ||||
| { | ||||
|    auto fiber = static_cast<Fiber *>(lpFiberParameter); | ||||
|    fiber->_start(nullptr); | ||||
|    auto fiber = static_cast<Fiber *>(fiber_parameter); | ||||
|    fiber->start(); | ||||
| } | ||||
| 
 | ||||
| Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter) | ||||
|  | @ -74,30 +77,26 @@ std::shared_ptr<Fiber> Fiber::ThreadToFiber() { | |||
| 
 | ||||
| #else | ||||
| 
 | ||||
| #include <boost/context/detail/fcontext.hpp> | ||||
| 
 | ||||
| constexpr std::size_t default_stack_size = 1024 * 1024 * 4; // 4MB
 | ||||
| 
 | ||||
| struct Fiber::FiberImpl { | ||||
|     boost::context::detail::fcontext_t context; | ||||
| struct alignas(64) Fiber::FiberImpl { | ||||
|     std::array<u8, default_stack_size> stack; | ||||
|     boost::context::detail::fcontext_t context; | ||||
| }; | ||||
| 
 | ||||
| void Fiber::_start(void* parameter) { | ||||
|     guard.lock(); | ||||
|     boost::context::detail::transfer_t* transfer = static_cast<boost::context::detail::transfer_t*>(parameter); | ||||
| void Fiber::start(boost::context::detail::transfer_t& transfer) { | ||||
|     if (previous_fiber) { | ||||
|         previous_fiber->impl->context = transfer->fctx; | ||||
|         previous_fiber->impl->context = transfer.fctx; | ||||
|         previous_fiber->guard.unlock(); | ||||
|         previous_fiber = nullptr; | ||||
|     } | ||||
|     entry_point(start_parameter); | ||||
| } | ||||
| 
 | ||||
| static void FiberStartFunc(boost::context::detail::transfer_t transfer) | ||||
| void Fiber::FiberStartFunc(boost::context::detail::transfer_t transfer) | ||||
| { | ||||
|    auto fiber = static_cast<Fiber *>(transfer.data); | ||||
|    fiber->_start(&transfer); | ||||
|    fiber->start(transfer); | ||||
| } | ||||
| 
 | ||||
| Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter) | ||||
|  | @ -139,6 +138,7 @@ void Fiber::YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to) { | |||
| 
 | ||||
| std::shared_ptr<Fiber> Fiber::ThreadToFiber() { | ||||
|     std::shared_ptr<Fiber> fiber = std::shared_ptr<Fiber>{new Fiber()}; | ||||
|     fiber->guard.lock(); | ||||
|     fiber->is_thread_fiber = true; | ||||
|     return fiber; | ||||
| } | ||||
|  |  | |||
|  | @ -10,6 +10,12 @@ | |||
| #include "common/common_types.h" | ||||
| #include "common/spin_lock.h" | ||||
| 
 | ||||
| #ifndef _MSC_VER | ||||
| namespace boost::context::detail { | ||||
|     struct transfer_t; | ||||
| } | ||||
| #endif | ||||
| 
 | ||||
| namespace Common { | ||||
| 
 | ||||
| class Fiber { | ||||
|  | @ -31,9 +37,6 @@ public: | |||
|     /// Only call from main thread's fiber
 | ||||
|     void Exit(); | ||||
| 
 | ||||
|     /// Used internally but required to be public, Shall not be used
 | ||||
|     void _start(void* parameter); | ||||
| 
 | ||||
|     /// Changes the start parameter of the fiber. Has no effect if the fiber already started
 | ||||
|     void SetStartParameter(void* new_parameter) { | ||||
|         start_parameter = new_parameter; | ||||
|  | @ -42,6 +45,16 @@ public: | |||
| private: | ||||
|     Fiber(); | ||||
| 
 | ||||
| #ifdef _MSC_VER | ||||
|     void start(); | ||||
|     static void FiberStartFunc(void* fiber_parameter); | ||||
| #else | ||||
|     void start(boost::context::detail::transfer_t& transfer); | ||||
|     static void FiberStartFunc(boost::context::detail::transfer_t transfer); | ||||
| #endif | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     struct FiberImpl; | ||||
| 
 | ||||
|     SpinLock guard; | ||||
|  |  | |||
|  | @ -1,6 +1,7 @@ | |||
| add_executable(tests | ||||
|     common/bit_field.cpp | ||||
|     common/bit_utils.cpp | ||||
|     common/fibers.cpp | ||||
|     common/multi_level_queue.cpp | ||||
|     common/param_package.cpp | ||||
|     common/ring_buffer.cpp | ||||
|  |  | |||
							
								
								
									
										214
									
								
								src/tests/common/fibers.cpp
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										214
									
								
								src/tests/common/fibers.cpp
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,214 @@ | |||
| // Copyright 2020 yuzu Emulator Project
 | ||||
| // Licensed under GPLv2 or any later version
 | ||||
| // Refer to the license.txt file included.
 | ||||
| 
 | ||||
| #include <atomic> | ||||
| #include <cstdlib> | ||||
| #include <functional> | ||||
| #include <memory> | ||||
| #include <thread> | ||||
| #include <unordered_map> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include <catch2/catch.hpp> | ||||
| #include <math.h> | ||||
| #include "common/common_types.h" | ||||
| #include "common/fiber.h" | ||||
| #include "common/spin_lock.h" | ||||
| 
 | ||||
| namespace Common { | ||||
| 
 | ||||
| class TestControl1 { | ||||
| public: | ||||
|     TestControl1() = default; | ||||
| 
 | ||||
|     void DoWork(); | ||||
| 
 | ||||
|     void ExecuteThread(u32 id); | ||||
| 
 | ||||
|     std::unordered_map<std::thread::id, u32> ids; | ||||
|     std::vector<std::shared_ptr<Common::Fiber>> thread_fibers; | ||||
|     std::vector<std::shared_ptr<Common::Fiber>> work_fibers; | ||||
|     std::vector<u32> items; | ||||
|     std::vector<u32> results; | ||||
| }; | ||||
| 
 | ||||
| static void WorkControl1(void* control) { | ||||
|     TestControl1* test_control = static_cast<TestControl1*>(control); | ||||
|     test_control->DoWork(); | ||||
| } | ||||
| 
 | ||||
| void TestControl1::DoWork() { | ||||
|     std::thread::id this_id = std::this_thread::get_id(); | ||||
|     u32 id = ids[this_id]; | ||||
|     u32 value = items[id]; | ||||
|     for (u32 i = 0; i < id; i++) { | ||||
|         value++; | ||||
|     } | ||||
|     results[id] = value; | ||||
|     Fiber::YieldTo(work_fibers[id], thread_fibers[id]); | ||||
| } | ||||
| 
 | ||||
| void TestControl1::ExecuteThread(u32 id) { | ||||
|     std::thread::id this_id = std::this_thread::get_id(); | ||||
|     ids[this_id] = id; | ||||
|     auto thread_fiber = Fiber::ThreadToFiber(); | ||||
|     thread_fibers[id] = thread_fiber; | ||||
|     work_fibers[id] = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl1}, this); | ||||
|     items[id] = rand() % 256; | ||||
|     Fiber::YieldTo(thread_fibers[id], work_fibers[id]); | ||||
|     thread_fibers[id]->Exit(); | ||||
| } | ||||
| 
 | ||||
| static void ThreadStart1(u32 id, TestControl1& test_control) { | ||||
|     test_control.ExecuteThread(id); | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| TEST_CASE("Fibers::Setup", "[common]") { | ||||
|     constexpr u32 num_threads = 7; | ||||
|     TestControl1 test_control{}; | ||||
|     test_control.thread_fibers.resize(num_threads, nullptr); | ||||
|     test_control.work_fibers.resize(num_threads, nullptr); | ||||
|     test_control.items.resize(num_threads, 0); | ||||
|     test_control.results.resize(num_threads, 0); | ||||
|     std::vector<std::thread> threads; | ||||
|     for (u32 i = 0; i < num_threads; i++) { | ||||
|         threads.emplace_back(ThreadStart1, i, std::ref(test_control)); | ||||
|     } | ||||
|     for (u32 i = 0; i < num_threads; i++) { | ||||
|         threads[i].join(); | ||||
|     } | ||||
|     for (u32 i = 0; i < num_threads; i++) { | ||||
|         REQUIRE(test_control.items[i] + i == test_control.results[i]); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| class TestControl2 { | ||||
| public: | ||||
|     TestControl2() = default; | ||||
| 
 | ||||
|     void DoWork1() { | ||||
|         trap2 = false; | ||||
|         while (trap.load()); | ||||
|         for (u32 i = 0; i < 12000; i++) { | ||||
|             value1 += i; | ||||
|         } | ||||
|         Fiber::YieldTo(fiber1, fiber3); | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         assert1 = id == 1; | ||||
|         value2 += 5000; | ||||
|         Fiber::YieldTo(fiber1, thread_fibers[id]); | ||||
|     } | ||||
| 
 | ||||
|     void DoWork2() { | ||||
|         while (trap2.load()); | ||||
|         value2 = 2000; | ||||
|         trap = false; | ||||
|         Fiber::YieldTo(fiber2, fiber1); | ||||
|         assert3 = false; | ||||
|     } | ||||
| 
 | ||||
|     void DoWork3() { | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         assert2 = id == 0; | ||||
|         value1 += 1000; | ||||
|         Fiber::YieldTo(fiber3, thread_fibers[id]); | ||||
|     } | ||||
| 
 | ||||
|     void ExecuteThread(u32 id); | ||||
| 
 | ||||
|     void CallFiber1() { | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         Fiber::YieldTo(thread_fibers[id], fiber1); | ||||
|     } | ||||
| 
 | ||||
|     void CallFiber2() { | ||||
|         std::thread::id this_id = std::this_thread::get_id(); | ||||
|         u32 id = ids[this_id]; | ||||
|         Fiber::YieldTo(thread_fibers[id], fiber2); | ||||
|     } | ||||
| 
 | ||||
|     void Exit(); | ||||
| 
 | ||||
|     bool assert1{}; | ||||
|     bool assert2{}; | ||||
|     bool assert3{true}; | ||||
|     u32 value1{}; | ||||
|     u32 value2{}; | ||||
|     std::atomic<bool> trap{true}; | ||||
|     std::atomic<bool> trap2{true}; | ||||
|     std::unordered_map<std::thread::id, u32> ids; | ||||
|     std::vector<std::shared_ptr<Common::Fiber>> thread_fibers; | ||||
|     std::shared_ptr<Common::Fiber> fiber1; | ||||
|     std::shared_ptr<Common::Fiber> fiber2; | ||||
|     std::shared_ptr<Common::Fiber> fiber3; | ||||
| }; | ||||
| 
 | ||||
| static void WorkControl2_1(void* control) { | ||||
|     TestControl2* test_control = static_cast<TestControl2*>(control); | ||||
|     test_control->DoWork1(); | ||||
| } | ||||
| 
 | ||||
| static void WorkControl2_2(void* control) { | ||||
|     TestControl2* test_control = static_cast<TestControl2*>(control); | ||||
|     test_control->DoWork2(); | ||||
| } | ||||
| 
 | ||||
| static void WorkControl2_3(void* control) { | ||||
|     TestControl2* test_control = static_cast<TestControl2*>(control); | ||||
|     test_control->DoWork3(); | ||||
| } | ||||
| 
 | ||||
| void TestControl2::ExecuteThread(u32 id) { | ||||
|     std::thread::id this_id = std::this_thread::get_id(); | ||||
|     ids[this_id] = id; | ||||
|     auto thread_fiber = Fiber::ThreadToFiber(); | ||||
|     thread_fibers[id] = thread_fiber; | ||||
| } | ||||
| 
 | ||||
| void TestControl2::Exit() { | ||||
|     std::thread::id this_id = std::this_thread::get_id(); | ||||
|     u32 id = ids[this_id]; | ||||
|     thread_fibers[id]->Exit(); | ||||
| } | ||||
| 
 | ||||
| static void ThreadStart2_1(u32 id, TestControl2& test_control) { | ||||
|     test_control.ExecuteThread(id); | ||||
|     test_control.CallFiber1(); | ||||
|     test_control.Exit(); | ||||
| } | ||||
| 
 | ||||
| static void ThreadStart2_2(u32 id, TestControl2& test_control) { | ||||
|     test_control.ExecuteThread(id); | ||||
|     test_control.CallFiber2(); | ||||
|     test_control.Exit(); | ||||
| } | ||||
| 
 | ||||
| TEST_CASE("Fibers::InterExchange", "[common]") { | ||||
|     TestControl2 test_control{}; | ||||
|     test_control.thread_fibers.resize(2, nullptr); | ||||
|     test_control.fiber1 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_1}, &test_control); | ||||
|     test_control.fiber2 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_2}, &test_control); | ||||
|     test_control.fiber3 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_3}, &test_control); | ||||
|     std::thread thread1(ThreadStart2_1, 0, std::ref(test_control)); | ||||
|     std::thread thread2(ThreadStart2_2, 1, std::ref(test_control)); | ||||
|     thread1.join(); | ||||
|     thread2.join(); | ||||
|     REQUIRE(test_control.assert1); | ||||
|     REQUIRE(test_control.assert2); | ||||
|     REQUIRE(test_control.assert3); | ||||
|     REQUIRE(test_control.value2 == 7000); | ||||
|     u32 cal_value = 0; | ||||
|     for (u32 i = 0; i < 12000; i++) { | ||||
|         cal_value += i; | ||||
|     } | ||||
|     cal_value += 1000; | ||||
|     REQUIRE(test_control.value1 == cal_value); | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| } // namespace Common
 | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Fernando Sahmkow
						Fernando Sahmkow