Skip to content

Commit

Permalink
fixup! KTOR-8105 Fix for concurrent flush attempts breaking CIO clien…
Browse files Browse the repository at this point in the history
…t (again)
  • Loading branch information
bjhham committed Feb 13, 2025
1 parent e4f9fdf commit 316806c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ internal class TLSClientHandshake(
rawOutput: ByteWriteChannel,
private val config: TLSConfig,
override val coroutineContext: CoroutineContext,
private val closeDeferred: CompletableDeferred<Unit> = CompletableDeferred<Unit>(),
private val closeTask: CompletableJob = Job(),
) : CoroutineScope {
private val digest = Digest()
private val clientSeed: ByteArray = config.random.generateClientSeed()
Expand Down Expand Up @@ -129,15 +129,15 @@ internal class TLSClientHandshake(
val record = if (useCipher) cipher.encrypt(closeRecord) else closeRecord
rawOutput.writeRecord(record)
rawOutput.flushAndClose()
closeDeferred.complete(Unit)
closeTask.complete()
}
}
}

fun close(): Deferred<Unit> {
fun close(): Job {
input.cancel()
output.close()
return closeDeferred
return closeTask
}

@OptIn(ExperimentalCoroutinesApi::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import io.ktor.network.util.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import io.ktor.utils.io.pool.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import java.nio.*
import kotlin.coroutines.*
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.ClosedSendChannelException
import kotlinx.coroutines.channels.consumeEach
import java.nio.ByteBuffer
import kotlin.coroutines.CoroutineContext

internal actual suspend fun openTLSSession(
socket: Socket,
Expand All @@ -26,11 +28,11 @@ internal actual suspend fun openTLSSession(
handshake.negotiate()
} catch (cause: Exception) {
runCatching {
handshake.close().await()
handshake.close().join()
socket.close()
}
if (cause is ClosedSendChannelException) {
throw TlsException("Negotiation failed due to EOS", cause)
throw TLSException("Negotiation failed due to EOS", cause)
} else {
throw cause
}
Expand Down Expand Up @@ -68,7 +70,7 @@ private class TLSSocket(
pipe.writePacket(record.packet)
pipe.flush()
}
else -> throw TlsException("Unexpected record ${record.type} ($length bytes)")
else -> throw TLSException("Unexpected record ${record.type} ($length bytes)")
}
}
} catch (_: Throwable) {
Expand Down

0 comments on commit 316806c

Please sign in to comment.