From 183b66d521da018f5141e2814f488447d1a4cba8 Mon Sep 17 00:00:00 2001 From: Honfika Date: Sat, 14 Apr 2018 19:49:34 +0200 Subject: [PATCH] CancellationTokens added --- .../Network/ClientHelloAlpnAdderStream.cs | 4 +- StreamExtended/Network/CopyStream.cs | 19 ++++---- StreamExtended/Network/CustomBinaryReader.cs | 27 ++++++----- StreamExtended/Network/CustomBinaryWriter.cs | 9 ++-- .../Network/CustomBufferedPeekStream.cs | 13 ++--- .../Network/CustomBufferedStream.cs | 24 ++++------ StreamExtended/Network/IBufferedStream.cs | 7 +-- .../Network/ServerHelloAlpnAdderStream.cs | 4 +- StreamExtended/SslTools.cs | 47 ++++++++++--------- 9 files changed, 79 insertions(+), 75 deletions(-) diff --git a/StreamExtended/Network/ClientHelloAlpnAdderStream.cs b/StreamExtended/Network/ClientHelloAlpnAdderStream.cs index 1edf4e7..ca20026 100644 --- a/StreamExtended/Network/ClientHelloAlpnAdderStream.cs +++ b/StreamExtended/Network/ClientHelloAlpnAdderStream.cs @@ -1,5 +1,6 @@ using System.Diagnostics; using System.IO; +using System.Threading; namespace StreamExtended.Network { @@ -47,7 +48,8 @@ public override void Write(byte[] buffer, int offset, int count) var ms = new MemoryStream(buffer, offset, count); //this can be non async, because reads from a memory stream - var clientHello = SslTools.PeekClientHello(new CustomBufferedStream(ms, (int)ms.Length)).Result; + var cts = new CancellationTokenSource(); + var clientHello = SslTools.PeekClientHello(new CustomBufferedStream(ms, (int)ms.Length), cts.Token).Result; if (clientHello != null) { // 0x00 0x10: ALPN identifier diff --git a/StreamExtended/Network/CopyStream.cs b/StreamExtended/Network/CopyStream.cs index 9eeda9e..1be9c82 100644 --- a/StreamExtended/Network/CopyStream.cs +++ b/StreamExtended/Network/CopyStream.cs @@ -1,4 +1,5 @@ using System; +using System.Threading; using System.Threading.Tasks; using StreamExtended.Helpers; @@ -30,18 +31,18 @@ public CopyStream(CustomBinaryReader reader, CustomBinaryWriter writer, int buff buffer = BufferPool.GetBuffer(bufferSize); } - public async Task FillBufferAsync() + public async Task FillBufferAsync(CancellationToken cancellationToken = default(CancellationToken)) { - await FlushAsync(); - return await reader.FillBufferAsync(); + await FlushAsync(cancellationToken); + return await reader.FillBufferAsync(cancellationToken); } - public async Task FlushAsync() + public async Task FlushAsync(CancellationToken cancellationToken = default(CancellationToken)) { //send out the current data from from the buffer if (bufferLength > 0) { - await writer.WriteAsync(buffer, 0, bufferLength); + await writer.WriteAsync(buffer, 0, bufferLength, cancellationToken); bufferLength = 0; } } @@ -54,20 +55,20 @@ public byte ReadByteFromBuffer() return b; } - public async Task ReadAsync(byte[] buffer, int offset, int count) + public async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) { - int result = await reader.ReadBytesAsync(buffer, offset, count); + int result = await reader.ReadBytesAsync(buffer, offset, count, cancellationToken); if (result > 0) { if (bufferLength + result > bufferSize) { - await FlushAsync(); + await FlushAsync(cancellationToken); } Buffer.BlockCopy(buffer, offset, this.buffer, bufferLength, result); bufferLength += result; ReadBytes += result; - await FlushAsync(); + await FlushAsync(cancellationToken); } return result; diff --git a/StreamExtended/Network/CustomBinaryReader.cs b/StreamExtended/Network/CustomBinaryReader.cs index 6a99be1..e56fe48 100644 --- a/StreamExtended/Network/CustomBinaryReader.cs +++ b/StreamExtended/Network/CustomBinaryReader.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.Text; +using System.Threading; using System.Threading.Tasks; namespace StreamExtended.Network @@ -33,7 +34,7 @@ public CustomBinaryReader(IBufferedStream stream, int bufferSize) /// Read a line from the byte stream /// /// - public async Task ReadLineAsync() + public async Task ReadLineAsync(CancellationToken cancellationToken = default(CancellationToken)) { byte lastChar = default(byte); @@ -42,7 +43,7 @@ public async Task ReadLineAsync() // try to use the thread static buffer, usually it is enough var buffer = Buffer; - while (stream.DataAvailable || await stream.FillBufferAsync()) + while (stream.DataAvailable || await stream.FillBufferAsync(cancellationToken)) { byte newChar = stream.ReadByteFromBuffer(); buffer[bufferDataLength] = newChar; @@ -87,11 +88,11 @@ public async Task ReadLineAsync() /// Read until the last new line /// /// - public async Task> ReadAllLinesAsync() + public async Task> ReadAllLinesAsync(CancellationToken cancellationToken = default(CancellationToken)) { string tmpLine; var requestLines = new List(); - while (!string.IsNullOrEmpty(tmpLine = await ReadLineAsync())) + while (!string.IsNullOrEmpty(tmpLine = await ReadLineAsync(cancellationToken))) { requestLines.Add(tmpLine); } @@ -103,9 +104,9 @@ public async Task> ReadAllLinesAsync() /// Read until the last new line, ignores the result /// /// - public async Task ReadAndIgnoreAllLinesAsync() + public async Task ReadAndIgnoreAllLinesAsync(CancellationToken cancellationToken = default(CancellationToken)) { - while (!string.IsNullOrEmpty(await ReadLineAsync())) + while (!string.IsNullOrEmpty(await ReadLineAsync(cancellationToken))) { } } @@ -115,10 +116,11 @@ public async Task ReadAndIgnoreAllLinesAsync() /// /// /// + /// /// The number of bytes read - public Task ReadBytesAsync(byte[] buffer, int bytesToRead) + public Task ReadBytesAsync(byte[] buffer, int bytesToRead, CancellationToken cancellationToken = default(CancellationToken)) { - return stream.ReadAsync(buffer, 0, bytesToRead); + return stream.ReadAsync(buffer, 0, bytesToRead, cancellationToken); } /// @@ -127,10 +129,11 @@ public Task ReadBytesAsync(byte[] buffer, int bytesToRead) /// /// /// + /// /// The number of bytes read - public Task ReadBytesAsync(byte[] buffer, int offset, int bytesToRead) + public Task ReadBytesAsync(byte[] buffer, int offset, int bytesToRead, CancellationToken cancellationToken = default(CancellationToken)) { - return stream.ReadAsync(buffer, offset, bytesToRead); + return stream.ReadAsync(buffer, offset, bytesToRead, cancellationToken); } public bool DataAvailable => stream.DataAvailable; @@ -139,9 +142,9 @@ public Task ReadBytesAsync(byte[] buffer, int offset, int bytesToRead) /// Fills the buffer asynchronous. /// /// - public Task FillBufferAsync() + public Task FillBufferAsync(CancellationToken cancellationToken = default(CancellationToken)) { - return stream.FillBufferAsync(); + return stream.FillBufferAsync(cancellationToken); } public byte ReadByteFromBuffer() diff --git a/StreamExtended/Network/CustomBinaryWriter.cs b/StreamExtended/Network/CustomBinaryWriter.cs index 358edc5..56e22f4 100644 --- a/StreamExtended/Network/CustomBinaryWriter.cs +++ b/StreamExtended/Network/CustomBinaryWriter.cs @@ -1,4 +1,5 @@ using System.IO; +using System.Threading; using System.Threading.Tasks; namespace StreamExtended.Network @@ -12,14 +13,14 @@ protected CustomBinaryWriter(Stream stream) this.stream = stream; } - public Task WriteAsync(byte[] data, int offset, int count) + public Task WriteAsync(byte[] data, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) { - return stream.WriteAsync(data, offset, count); + return stream.WriteAsync(data, offset, count, cancellationToken); } - protected Task FlushAsync() + protected Task FlushAsync(CancellationToken cancellationToken = default(CancellationToken)) { - return stream.FlushAsync(); + return stream.FlushAsync(cancellationToken); } } } diff --git a/StreamExtended/Network/CustomBufferedPeekStream.cs b/StreamExtended/Network/CustomBufferedPeekStream.cs index 6292192..00d52a8 100644 --- a/StreamExtended/Network/CustomBufferedPeekStream.cs +++ b/StreamExtended/Network/CustomBufferedPeekStream.cs @@ -1,4 +1,5 @@ using System; +using System.Threading; using System.Threading.Tasks; namespace StreamExtended.Network @@ -25,9 +26,9 @@ internal CustomBufferedPeekStream(CustomBufferedStream baseStream, int startPosi /// internal int Available => baseStream.Available - Position; - internal async Task EnsureBufferLength(int length) + internal async Task EnsureBufferLength(int length, CancellationToken cancellationToken) { - var val = await baseStream.PeekByteAsync(Position + length - 1); + var val = await baseStream.PeekByteAsync(Position + length - 1, cancellationToken); return val != -1; } @@ -66,9 +67,9 @@ internal byte[] ReadBytes(int length) /// Fills the buffer asynchronous. /// /// - Task IBufferedStream.FillBufferAsync() + Task IBufferedStream.FillBufferAsync(CancellationToken cancellationToken = default(CancellationToken)) { - return baseStream.FillBufferAsync(); + return baseStream.FillBufferAsync(cancellationToken); } /// @@ -81,9 +82,9 @@ byte IBufferedStream.ReadByteFromBuffer() return ReadByte(); } - Task IBufferedStream.ReadAsync(byte[] buffer, int offset, int count) + Task IBufferedStream.ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - return baseStream.ReadAsync(buffer, offset, count); + return baseStream.ReadAsync(buffer, offset, count, cancellationToken); } } } diff --git a/StreamExtended/Network/CustomBufferedStream.cs b/StreamExtended/Network/CustomBufferedStream.cs index ff1da7f..0edcc64 100644 --- a/StreamExtended/Network/CustomBufferedStream.cs +++ b/StreamExtended/Network/CustomBufferedStream.cs @@ -122,7 +122,7 @@ public override void Write(byte[] buffer, int offset, int count) /// /// A task that represents the asynchronous copy operation. /// - public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken = default(CancellationToken)) { if (bufferLength > 0) { @@ -140,7 +140,7 @@ public override async Task CopyToAsync(Stream destination, int bufferSize, Cance /// /// A task that represents the asynchronous flush operation. /// - public override Task FlushAsync(CancellationToken cancellationToken) + public override Task FlushAsync(CancellationToken cancellationToken = default(CancellationToken)) { return baseStream.FlushAsync(cancellationToken); } @@ -165,7 +165,7 @@ public override Task FlushAsync(CancellationToken cancellationToken) /// less than the requested number, or it can be 0 (zero) /// if the end of the stream has been reached. /// - public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) { if (bufferLength == 0) { @@ -209,12 +209,13 @@ public override int ReadByte() /// Peeks a byte asynchronous. /// /// The index. + /// The cancellation token. /// - public async Task PeekByteAsync(int index) + public async Task PeekByteAsync(int index, CancellationToken cancellationToken = default(CancellationToken)) { if (Available <= index) { - await FillBufferAsync(); + await FillBufferAsync(cancellationToken); } if (Available <= index) @@ -268,7 +269,7 @@ public byte ReadByteFromBuffer() /// A task that represents the asynchronous write operation. /// [DebuggerStepThrough] - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) { OnDataSent(buffer, offset, count); return baseStream.WriteAsync(buffer, offset, count, cancellationToken); @@ -417,21 +418,12 @@ public bool FillBuffer() } } - /// - /// Fills the buffer asynchronous. - /// - /// - public Task FillBufferAsync() - { - return FillBufferAsync(CancellationToken.None); - } - /// /// Fills the buffer asynchronous. /// /// The cancellation token. /// - public async Task FillBufferAsync(CancellationToken cancellationToken) + public async Task FillBufferAsync(CancellationToken cancellationToken = default(CancellationToken)) { if (closed) { diff --git a/StreamExtended/Network/IBufferedStream.cs b/StreamExtended/Network/IBufferedStream.cs index 63d58c2..7ca0b24 100644 --- a/StreamExtended/Network/IBufferedStream.cs +++ b/StreamExtended/Network/IBufferedStream.cs @@ -1,4 +1,5 @@ -using System.Threading.Tasks; +using System.Threading; +using System.Threading.Tasks; namespace StreamExtended.Network { @@ -6,10 +7,10 @@ public interface IBufferedStream { bool DataAvailable { get; } - Task FillBufferAsync(); + Task FillBufferAsync(CancellationToken cancellationToken); byte ReadByteFromBuffer(); - Task ReadAsync(byte[] buffer, int offset, int count); + Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken); } } \ No newline at end of file diff --git a/StreamExtended/Network/ServerHelloAlpnAdderStream.cs b/StreamExtended/Network/ServerHelloAlpnAdderStream.cs index 52aed6b..746f564 100644 --- a/StreamExtended/Network/ServerHelloAlpnAdderStream.cs +++ b/StreamExtended/Network/ServerHelloAlpnAdderStream.cs @@ -1,5 +1,6 @@ using System.Diagnostics; using System.IO; +using System.Threading; namespace StreamExtended.Network { @@ -47,7 +48,8 @@ public override void Write(byte[] buffer, int offset, int count) var ms = new MemoryStream(buffer, offset, count); //this can be non async, because reads from a memory stream - var serverHello = SslTools.PeekServerHello(new CustomBufferedStream(ms, (int)ms.Length)).Result; + var cts = new CancellationTokenSource(); + var serverHello = SslTools.PeekServerHello(new CustomBufferedStream(ms, (int)ms.Length), cts.Token).Result; if (serverHello != null) { // 0x00 0x10: ALPN identifier diff --git a/StreamExtended/SslTools.cs b/StreamExtended/SslTools.cs index 14c6de9..a68f457 100644 --- a/StreamExtended/SslTools.cs +++ b/StreamExtended/SslTools.cs @@ -2,24 +2,25 @@ using StreamExtended.Network; using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace StreamExtended { public class SslTools { - public static async Task IsClientHello(CustomBufferedStream stream) + public static async Task IsClientHello(CustomBufferedStream stream, CancellationToken cancellationToken) { - var clientHello = await PeekClientHello(stream); + var clientHello = await PeekClientHello(stream, cancellationToken); return clientHello != null; } - public static async Task PeekClientHello(CustomBufferedStream clientStream) + public static async Task PeekClientHello(CustomBufferedStream clientStream, CancellationToken cancellationToken) { //detects the HTTPS ClientHello message as it is described in the following url: //https://stackoverflow.com/questions/3897883/how-to-detect-an-incoming-ssl-https-handshake-ssl-wire-format - int recordType = await clientStream.PeekByteAsync(0); + int recordType = await clientStream.PeekByteAsync(0, cancellationToken); if (recordType == -1) { return null; @@ -31,7 +32,7 @@ public static async Task PeekClientHello(CustomBufferedStream c var peekStream = new CustomBufferedPeekStream(clientStream, 1); // length value + minimum length - if (!await peekStream.EnsureBufferLength(10)) + if (!await peekStream.EnsureBufferLength(10, cancellationToken)) { return null; } @@ -56,7 +57,7 @@ public static async Task PeekClientHello(CustomBufferedStream c int sessionIdLength = peekStream.ReadInt16(); int randomLength = peekStream.ReadInt16(); - if (!await peekStream.EnsureBufferLength(ciphersCount * 3 + sessionIdLength + randomLength)) + if (!await peekStream.EnsureBufferLength(ciphersCount * 3 + sessionIdLength + randomLength, cancellationToken)) { return null; } @@ -89,7 +90,7 @@ public static async Task PeekClientHello(CustomBufferedStream c //should contain at least 43 bytes // 2 version + 2 length + 1 type + 3 length(?) + 2 version + 32 random + 1 sessionid length - if (!await peekStream.EnsureBufferLength(43)) + if (!await peekStream.EnsureBufferLength(43, cancellationToken)) { return null; } @@ -115,7 +116,7 @@ public static async Task PeekClientHello(CustomBufferedStream c length = peekStream.ReadByte(); // sessionid + 2 ciphersData length - if (!await peekStream.EnsureBufferLength(length + 2)) + if (!await peekStream.EnsureBufferLength(length + 2, cancellationToken)) { return null; } @@ -125,7 +126,7 @@ public static async Task PeekClientHello(CustomBufferedStream c length = peekStream.ReadInt16(); // ciphersData + compressionData length - if (!await peekStream.EnsureBufferLength(length + 1)) + if (!await peekStream.EnsureBufferLength(length + 1, cancellationToken)) { return null; } @@ -144,7 +145,7 @@ public static async Task PeekClientHello(CustomBufferedStream c } // compressionData - if (!await peekStream.EnsureBufferLength(length)) + if (!await peekStream.EnsureBufferLength(length, cancellationToken)) { return null; } @@ -153,7 +154,7 @@ public static async Task PeekClientHello(CustomBufferedStream c int extenstionsStartPosition = peekStream.Position; - var extensions = await ReadExtensions(majorVersion, minorVersion, peekStream); + var extensions = await ReadExtensions(majorVersion, minorVersion, peekStream, cancellationToken); var clientHelloInfo = new ClientHelloInfo { @@ -175,16 +176,16 @@ public static async Task PeekClientHello(CustomBufferedStream c return null; } - private static async Task> ReadExtensions(int majorVersion, int minorVersion, CustomBufferedPeekStream peekStream) + private static async Task> ReadExtensions(int majorVersion, int minorVersion, CustomBufferedPeekStream peekStream, CancellationToken cancellationToken) { Dictionary extensions = null; if (majorVersion > 3 || majorVersion == 3 && minorVersion >= 1) { - if (await peekStream.EnsureBufferLength(2)) + if (await peekStream.EnsureBufferLength(2, cancellationToken)) { int extensionsLength = peekStream.ReadInt16(); - if (await peekStream.EnsureBufferLength(extensionsLength)) + if (await peekStream.EnsureBufferLength(extensionsLength, cancellationToken)) { extensions = new Dictionary(); int idx = 0; @@ -204,18 +205,18 @@ private static async Task> ReadExtensions(int m return extensions; } - public static async Task IsServerHello(CustomBufferedStream stream) + public static async Task IsServerHello(CustomBufferedStream stream, CancellationToken cancellationToken) { - var serverHello = await PeekServerHello(stream); + var serverHello = await PeekServerHello(stream, cancellationToken); return serverHello != null; } - public static async Task PeekServerHello(CustomBufferedStream serverStream) + public static async Task PeekServerHello(CustomBufferedStream serverStream, CancellationToken cancellationToken) { //detects the HTTPS ClientHello message as it is described in the following url: //https://stackoverflow.com/questions/3897883/how-to-detect-an-incoming-ssl-https-handshake-ssl-wire-format - int recordType = await serverStream.PeekByteAsync(0); + int recordType = await serverStream.PeekByteAsync(0, cancellationToken); if (recordType == -1) { return null; @@ -228,7 +229,7 @@ public static async Task PeekServerHello(CustomBufferedStream s var peekStream = new CustomBufferedPeekStream(serverStream, 1); // length value + minimum length - if (!await peekStream.EnsureBufferLength(39)) + if (!await peekStream.EnsureBufferLength(39, cancellationToken)) { return null; } @@ -250,7 +251,7 @@ public static async Task PeekServerHello(CustomBufferedStream s int minorVersion = peekStream.ReadByte(); // 32 bytes random + 1 byte sessionId + 2 bytes cipherSuite - if (!await peekStream.EnsureBufferLength(35)) + if (!await peekStream.EnsureBufferLength(35, cancellationToken)) { return null; } @@ -278,7 +279,7 @@ public static async Task PeekServerHello(CustomBufferedStream s //should contain at least 43 bytes // 2 version + 2 length + 1 type + 3 length(?) + 2 version + 32 random + 1 sessionid length - if (!await peekStream.EnsureBufferLength(43)) + if (!await peekStream.EnsureBufferLength(43, cancellationToken)) { return null; } @@ -304,7 +305,7 @@ public static async Task PeekServerHello(CustomBufferedStream s length = peekStream.ReadByte(); // sessionid + cipherSuite + compressionMethod - if (!await peekStream.EnsureBufferLength(length + 2 + 1)) + if (!await peekStream.EnsureBufferLength(length + 2 + 1, cancellationToken)) { return null; } @@ -316,7 +317,7 @@ public static async Task PeekServerHello(CustomBufferedStream s int extenstionsStartPosition = peekStream.Position; - var extensions = await ReadExtensions(majorVersion, minorVersion, peekStream); + var extensions = await ReadExtensions(majorVersion, minorVersion, peekStream, cancellationToken); //var rawBytes = new CustomBufferedPeekStream(serverStream).ReadBytes(peekStream.Position);