forked from eden-emu/eden
		
	Merge pull request #11385 from liamwhite/acceptcancel
internal_network: cancel pending socket operations on application process termination
This commit is contained in:
		
						commit
						2a0025937d
					
				
					 3 changed files with 91 additions and 3 deletions
				
			
		|  | @ -406,6 +406,7 @@ struct System::Impl { | |||
|             gpu_core->NotifyShutdown(); | ||||
|         } | ||||
| 
 | ||||
|         Network::CancelPendingSocketOperations(); | ||||
|         kernel.SuspendApplication(true); | ||||
|         if (services) { | ||||
|             services->KillNVNFlinger(); | ||||
|  | @ -427,6 +428,7 @@ struct System::Impl { | |||
|         debugger.reset(); | ||||
|         kernel.Shutdown(); | ||||
|         memory.Reset(); | ||||
|         Network::RestartSocketOperations(); | ||||
| 
 | ||||
|         if (auto room_member = room_network.GetRoomMember().lock()) { | ||||
|             Network::GameInfo game_info{}; | ||||
|  |  | |||
|  | @ -48,15 +48,32 @@ enum class CallType { | |||
| 
 | ||||
| using socklen_t = int; | ||||
| 
 | ||||
| SOCKET interrupt_socket = static_cast<SOCKET>(-1); | ||||
| 
 | ||||
| void InterruptSocketOperations() { | ||||
|     closesocket(interrupt_socket); | ||||
| } | ||||
| 
 | ||||
| void AcknowledgeInterrupt() { | ||||
|     interrupt_socket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); | ||||
| } | ||||
| 
 | ||||
| void Initialize() { | ||||
|     WSADATA wsa_data; | ||||
|     (void)WSAStartup(MAKEWORD(2, 2), &wsa_data); | ||||
| 
 | ||||
|     AcknowledgeInterrupt(); | ||||
| } | ||||
| 
 | ||||
| void Finalize() { | ||||
|     InterruptSocketOperations(); | ||||
|     WSACleanup(); | ||||
| } | ||||
| 
 | ||||
| SOCKET GetInterruptSocket() { | ||||
|     return interrupt_socket; | ||||
| } | ||||
| 
 | ||||
| sockaddr TranslateFromSockAddrIn(SockAddrIn input) { | ||||
|     sockaddr_in result; | ||||
| 
 | ||||
|  | @ -157,9 +174,42 @@ constexpr int SD_RECEIVE = SHUT_RD; | |||
| constexpr int SD_SEND = SHUT_WR; | ||||
| constexpr int SD_BOTH = SHUT_RDWR; | ||||
| 
 | ||||
| void Initialize() {} | ||||
| int interrupt_pipe_fd[2] = {-1, -1}; | ||||
| 
 | ||||
| void Finalize() {} | ||||
| void Initialize() { | ||||
|     if (pipe(interrupt_pipe_fd) != 0) { | ||||
|         LOG_ERROR(Network, "Failed to create interrupt pipe!"); | ||||
|     } | ||||
|     int flags = fcntl(interrupt_pipe_fd[0], F_GETFL); | ||||
|     ASSERT_MSG(fcntl(interrupt_pipe_fd[0], F_SETFL, flags | O_NONBLOCK) == 0, | ||||
|                "Failed to set nonblocking state for interrupt pipe"); | ||||
| } | ||||
| 
 | ||||
| void Finalize() { | ||||
|     if (interrupt_pipe_fd[0] >= 0) { | ||||
|         close(interrupt_pipe_fd[0]); | ||||
|     } | ||||
|     if (interrupt_pipe_fd[1] >= 0) { | ||||
|         close(interrupt_pipe_fd[1]); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| void InterruptSocketOperations() { | ||||
|     u8 value = 0; | ||||
|     ASSERT(write(interrupt_pipe_fd[1], &value, sizeof(value)) == 1); | ||||
| } | ||||
| 
 | ||||
| void AcknowledgeInterrupt() { | ||||
|     u8 value = 0; | ||||
|     ssize_t ret = read(interrupt_pipe_fd[0], &value, sizeof(value)); | ||||
|     if (ret != 1 && errno != EAGAIN && errno != EWOULDBLOCK) { | ||||
|         LOG_ERROR(Network, "Failed to acknowledge interrupt on shutdown"); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| SOCKET GetInterruptSocket() { | ||||
|     return interrupt_pipe_fd[0]; | ||||
| } | ||||
| 
 | ||||
| sockaddr TranslateFromSockAddrIn(SockAddrIn input) { | ||||
|     sockaddr_in result; | ||||
|  | @ -490,6 +540,14 @@ NetworkInstance::~NetworkInstance() { | |||
|     Finalize(); | ||||
| } | ||||
| 
 | ||||
| void CancelPendingSocketOperations() { | ||||
|     InterruptSocketOperations(); | ||||
| } | ||||
| 
 | ||||
| void RestartSocketOperations() { | ||||
|     AcknowledgeInterrupt(); | ||||
| } | ||||
| 
 | ||||
| std::optional<IPv4Address> GetHostIPv4Address() { | ||||
|     const auto network_interface = Network::GetSelectedNetworkInterface(); | ||||
|     if (!network_interface.has_value()) { | ||||
|  | @ -560,7 +618,14 @@ std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) { | |||
|         return result; | ||||
|     }); | ||||
| 
 | ||||
|     const int result = WSAPoll(host_pollfds.data(), static_cast<ULONG>(num), timeout); | ||||
|     host_pollfds.push_back(WSAPOLLFD{ | ||||
|         .fd = GetInterruptSocket(), | ||||
|         .events = POLLIN, | ||||
|         .revents = 0, | ||||
|     }); | ||||
| 
 | ||||
|     const int result = | ||||
|         WSAPoll(host_pollfds.data(), static_cast<ULONG>(host_pollfds.size()), timeout); | ||||
|     if (result == 0) { | ||||
|         ASSERT(std::all_of(host_pollfds.begin(), host_pollfds.end(), | ||||
|                            [](WSAPOLLFD fd) { return fd.revents == 0; })); | ||||
|  | @ -627,6 +692,24 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { | |||
| std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() { | ||||
|     sockaddr_in addr; | ||||
|     socklen_t addrlen = sizeof(addr); | ||||
| 
 | ||||
|     std::vector<WSAPOLLFD> host_pollfds{ | ||||
|         WSAPOLLFD{fd, POLLIN, 0}, | ||||
|         WSAPOLLFD{GetInterruptSocket(), POLLIN, 0}, | ||||
|     }; | ||||
| 
 | ||||
|     while (true) { | ||||
|         const int pollres = | ||||
|             WSAPoll(host_pollfds.data(), static_cast<ULONG>(host_pollfds.size()), -1); | ||||
|         if (host_pollfds[1].revents != 0) { | ||||
|             // Interrupt signaled before a client could be accepted, break
 | ||||
|             return {AcceptResult{}, Errno::AGAIN}; | ||||
|         } | ||||
|         if (pollres > 0) { | ||||
|             break; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen); | ||||
| 
 | ||||
|     if (new_socket == INVALID_SOCKET) { | ||||
|  |  | |||
|  | @ -96,6 +96,9 @@ public: | |||
|     ~NetworkInstance(); | ||||
| }; | ||||
| 
 | ||||
| void CancelPendingSocketOperations(); | ||||
| void RestartSocketOperations(); | ||||
| 
 | ||||
| #ifdef _WIN32 | ||||
| constexpr IPv4Address TranslateIPv4(in_addr addr) { | ||||
|     auto& bytes = addr.S_un.S_un_b; | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 liamwhite
						liamwhite