diff --git a/Sources/NIOPosix/SocketChannel.swift b/Sources/NIOPosix/SocketChannel.swift index 51fdc3a87e..271c78cd4d 100644 --- a/Sources/NIOPosix/SocketChannel.swift +++ b/Sources/NIOPosix/SocketChannel.swift @@ -263,6 +263,7 @@ final class ServerSocketChannel: BaseSocketChannel { // It's important to call the methods before we actually notify the original promise for ordering reasons. self.becomeActive0(promise: promise) }.whenFailure{ error in + self.close0(error: error, mode: .all, promise: nil) promise?.fail(error) } executeAndComplete(p) { diff --git a/Tests/NIOPosixTests/AcceptBackoffHandlerTest.swift b/Tests/NIOPosixTests/AcceptBackoffHandlerTest.swift index 015c6382ed..7c69d951c6 100644 --- a/Tests/NIOPosixTests/AcceptBackoffHandlerTest.swift +++ b/Tests/NIOPosixTests/AcceptBackoffHandlerTest.swift @@ -265,7 +265,7 @@ public final class AcceptBackoffHandlerTest: XCTestCase { name: self.acceptHandlerName) }.wait()) - XCTAssertNoThrow(try eventLoop.flatSubmit { + let bindFuture = eventLoop.flatSubmit { // this is pretty delicate at the moment: // `bind` must be _synchronously_ follow `register`, otherwise in our current implementation, `epoll` will // send us `EPOLLHUP`. To have it run synchronously, we need to invoke the `flatMap` on the eventloop that the @@ -273,7 +273,11 @@ public final class AcceptBackoffHandlerTest: XCTestCase { serverChannel.register().flatMap { () -> EventLoopFuture<()> in return serverChannel.bind(to: try! SocketAddress(ipAddress: "127.0.0.1", port: 0)) } - }.wait() as Void) + } + + // If bind fails, the error will propagate up + try bindFuture.wait() + return serverChannel } } diff --git a/Tests/NIOPosixTests/SALChannelTests.swift b/Tests/NIOPosixTests/SALChannelTests.swift index 785d7d0d28..0e893fa543 100644 --- a/Tests/NIOPosixTests/SALChannelTests.swift +++ b/Tests/NIOPosixTests/SALChannelTests.swift @@ -322,6 +322,48 @@ final class SALChannelTest: XCTestCase, SALTest { } }.salWait()) } + + func testBindFailureClosesChannel() { + guard let channel = try? self.makeSocketChannel() else { + XCTFail("couldn't make a channel") + return + } + let localAddress = try! SocketAddress(ipAddress: "1.2.3.4", port: 5) + let serverAddress = try! SocketAddress(ipAddress: "9.8.7.6", port: 5) + + XCTAssertThrowsError(try channel.eventLoop.runSAL(syscallAssertions: { + try self.assertSetOption(expectedLevel: .tcp, expectedOption: .tcp_nodelay) { value in + return (value as? SocketOptionValue) == 1 + } + try self.assertSetOption(expectedLevel: .socket, expectedOption: .so_reuseaddr) { value in + return (value as? SocketOptionValue) == 1 + } + try self.assertBind(expectedAddress: localAddress, errorReturn: IOError(errnoCode: EPERM, reason: "bind")) + try self.assertDeregister { selectable in + return true + } + try self.assertClose(expectedFD: .max) + + }) { + ClientBootstrap(group: channel.eventLoop) + .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .channelOption(ChannelOptions.autoRead, value: false) + .bind(to: localAddress) + .testOnly_connect(injectedChannel: channel, to: serverAddress) + .flatMapError { error in + guard let ioError = error as? IOError else { + XCTFail("Expected IOError, got \(error)") + return channel.eventLoop.makeFailedFuture(error) + } + XCTAssertEqual(ioError.errnoCode, EPERM, "Expected EPERM error code") + XCTAssertTrue(!channel.isActive, "Channel should be closed") + return channel.eventLoop.makeFailedFuture(error) + } + .flatMap { channel in + channel.close() + } + }.salWait()) + } func testAcceptingInboundConnections() throws { final class ConnectionRecorder: ChannelInboundHandler { diff --git a/Tests/NIOPosixTests/SyscallAbstractionLayer.swift b/Tests/NIOPosixTests/SyscallAbstractionLayer.swift index 8da3d76b97..06c382ec47 100644 --- a/Tests/NIOPosixTests/SyscallAbstractionLayer.swift +++ b/Tests/NIOPosixTests/SyscallAbstractionLayer.swift @@ -874,9 +874,10 @@ extension SALTest { } } - func assertBind(expectedAddress: SocketAddress, file: StaticString = #filePath, line: UInt = #line) throws { + func assertBind(expectedAddress: SocketAddress, errorReturn: IOError? = nil, file: StaticString = #filePath, line: UInt = #line) throws { SAL.printIfDebug("\(#function)") - try self.selector.assertSyscallAndReturn(.returnVoid, file: (file), line: line) { syscall in + + try self.selector.assertSyscallAndReturn(errorReturn != nil ? .error(errorReturn!) : .returnVoid, file: file, line: line) { syscall in if case .bind(let address) = syscall { return address == expectedAddress } else {