Skip to content

Commit

Permalink
Use value tasks for async network stream (closes #4)
Browse files Browse the repository at this point in the history
  • Loading branch information
solyutor committed Mar 12, 2020
1 parent 1f567d0 commit ffa9e07
Show file tree
Hide file tree
Showing 12 changed files with 700 additions and 111 deletions.
36 changes: 33 additions & 3 deletions src/SolarWind.Tests/TestContext.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,41 @@
using System;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Console;

namespace Codestellation.SolarWind.Tests
{
public static class TestContext
{
public static readonly ILoggerFactory LoggerFactory =
new LoggerFactory(new ILoggerProvider[] {new ConsoleLoggerProvider((s, level) => true, false)});
public static readonly ILoggerFactory LoggerFactory = new ConsoleLoggerFactory();

public class ConsoleLoggerFactory : ILoggerFactory
{
public void Dispose() => throw new NotImplementedException();

public ILogger CreateLogger(string categoryName)
=> ConsoleLogger.Instance;

public void AddProvider(ILoggerProvider provider) => throw new NotImplementedException();
}

public class ConsoleLogger : ILogger
{
public static readonly ConsoleLogger Instance = new ConsoleLogger();

private ConsoleLogger()
{
}

public void Log<TState>(
LogLevel logLevel,
EventId eventId,
TState state,
Exception exception,
Func<TState, Exception, string> formatter)
=> Console.WriteLine(formatter(state, exception));

public bool IsEnabled(LogLevel logLevel) => true;

public IDisposable BeginScope<TState>(TState state) => throw new NotImplementedException();
}
}
}
151 changes: 96 additions & 55 deletions src/SolarWind/Internals/AsyncNetworkStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,39 @@ namespace Codestellation.SolarWind.Internals
// so currently they are overridden for compatibility reasons
public class AsyncNetworkStream : NetworkStream
{
#if NETSTANDARD2_0
/*private CompletionSourceAsyncEventArgs _receiveArgs;
private CompletionSourceAsyncEventArgs _sendArgs;*/
#endif
public Socket UnderlyingSocket => Socket;

public AsyncNetworkStream(Socket socket) : base(socket, true)
{
#if NETSTANDARD2_0
/*_receiveArgs = SocketEventArgsPool.Instance.Get();
_sendArgs = SocketEventArgsPool.Instance.Get();*/
#endif
}

#if NETSTANDARD2_0
protected override void Dispose(bool disposing)
{
/*CompletionSourceAsyncEventArgs receiveArgs = Interlocked.Exchange(ref _receiveArgs, null);
if (receiveArgs != null)
{
SocketEventArgsPool.Instance.Return(receiveArgs);
}
CompletionSourceAsyncEventArgs sendArgs = Interlocked.Exchange(ref _sendArgs, null);
if (sendArgs != null)
{
SocketEventArgsPool.Instance.Return(sendArgs);
}*/
}

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellation)
=> ReadAsync(new Memory<byte>(buffer, offset, count), cancellation).AsTask();
//ReadAsync(new Memory<byte>(buffer, offset, count), cancellation).AsTask();
=> throw new NotSupportedException();

//See comments at the top

Expand All @@ -29,14 +53,15 @@ public async ValueTask<int> ReadAsync(Memory<byte> to, CancellationToken cancell
{
throw new InvalidOperationException("Non array base memory is supported for .net core 2.1+ only");
}

try
{
if (TryReceiveSyncNonBlock(segment, out int received))
if (!TryReceiveSyncNonBlock(segment, out var received))
{
return received;
received = await ReceiveAsync(segment).ConfigureAwait(false);
}

return await ReceiveAsync(segment).ConfigureAwait(false);
return received;
}
catch (Exception ex) when (ex is SocketException || ex is ObjectDisposedException)
{
Expand All @@ -46,36 +71,46 @@ public async ValueTask<int> ReadAsync(Memory<byte> to, CancellationToken cancell

private bool TryReceiveSyncNonBlock(in ArraySegment<byte> segment, out int received)
{

if (Socket.Available == 0)
var available = Socket.Available;
if (available == 0)
{
received = 0;
return false;
}

int bytesToRead = Math.Min(segment.Count, Socket.Available);
var bytesToRead = Math.Min(segment.Count, available);
received = Socket.Receive(segment.Array, segment.Offset, bytesToRead, SocketFlags.None);
return true;

}

private async Task<int> ReceiveAsync(ArraySegment<byte> segment)
private async ValueTask<int> ReceiveAsync(ArraySegment<byte> segment)
{
/*CompletionSourceAsyncEventArgs receiveArgs = Interlocked.Exchange(ref _receiveArgs, null);
var localArgs = receiveArgs != null;
if (!localArgs)
{
receiveArgs = SocketEventArgsPool.Instance.Get();
}*/
CompletionSourceAsyncEventArgs receiveArgs = SocketEventArgsPool.Instance.Get();
receiveArgs.SetBuffer(segment.Array, segment.Offset, segment.Count);

if (Socket.ReceiveAsync(receiveArgs))
{
await receiveArgs.Task.ConfigureAwait(false);
}

var source = new TaskCompletionSource<int>();
var args = new SocketAsyncEventArgs {UserToken = source};
args.Completed += HandleAsyncResult;
args.SetBuffer(segment.Array, segment.Offset, segment.Count);
var transferred = receiveArgs.BytesTransferred;

if (Socket.ReceiveAsync(args))
/*if (localArgs)
{
return await source.Task.ConfigureAwait(false);
_receiveArgs = receiveArgs;
}
else
{
SocketEventArgsPool.Instance.Return(receiveArgs);
}*/
SocketEventArgsPool.Instance.Return(receiveArgs);

int transferred = args.BytesTransferred;
args.Completed -= HandleAsyncResult;
args.UserToken = null;
args.Dispose();

// Zero transferred bytes means connection has been closed at the counterpart side.
// See https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socketasynceventargs.bytestransferred?view=netframework-4.7.2
Expand All @@ -87,8 +122,10 @@ private async Task<int> ReceiveAsync(ArraySegment<byte> segment)
return transferred;
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
//WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
=> throw new NotSupportedException();


//See comments at the top
public async ValueTask WriteAsync(ReadOnlyMemory<byte> from, CancellationToken cancellationToken)
Expand All @@ -100,11 +137,11 @@ public async ValueTask WriteAsync(ReadOnlyMemory<byte> from, CancellationToken c

try
{
int left = from.Length;
var left = from.Length;
var sent = 0;
while (left != 0)
{
int realOffset = segment.Offset + sent;
var realOffset = segment.Offset + sent;

if (!TrySendSyncNonBlock(ref sent, in segment, realOffset))
{
Expand All @@ -120,30 +157,43 @@ public async ValueTask WriteAsync(ReadOnlyMemory<byte> from, CancellationToken c
}
}

private async Task<int> SendAsync(ArraySegment<byte> segment, int realOffset, int left)
private async ValueTask<int> SendAsync(ArraySegment<byte> segment, int realOffset, int left)
{
var source = new TaskCompletionSource<int>();
var args = new SocketAsyncEventArgs {UserToken = source};
args.Completed += HandleAsyncResult;
/*CompletionSourceAsyncEventArgs sendArgs = Interlocked.Exchange(ref _sendArgs, null);
var localArgs = sendArgs != null;
if (!localArgs)
{
sendArgs = SocketEventArgsPool.Instance.Get();
}*/
CompletionSourceAsyncEventArgs sendArgs = SocketEventArgsPool.Instance.Get();
sendArgs.SetBuffer(segment.Array, realOffset, left);

args.SetBuffer(segment.Array, realOffset, left);
if (Socket.SendAsync(sendArgs))
{
await sendArgs.Task.ConfigureAwait(false);
}

//Operation has completed synchronously
var bytesTransferred = sendArgs.BytesTransferred;
SocketError socketError = sendArgs.SocketError;

if (Socket.SendAsync(args))
/*if (localArgs)
{
return await source.Task.ConfigureAwait(false);
_sendArgs = sendArgs;
}
else
{
SocketEventArgsPool.Instance.Return(sendArgs);
}*/

args.Completed -= HandleAsyncResult;
args.UserToken = null;
args.Dispose();
SocketEventArgsPool.Instance.Return(sendArgs);

//Operation has completed synchronously
if (args.SocketError == SocketError.Success)
if (socketError == SocketError.Success)
{
return args.BytesTransferred;
return bytesTransferred;
}

throw BuildIoException(args.SocketError);
throw BuildIoException(socketError);
}

private bool TrySendSyncNonBlock(ref int sent, in ArraySegment<byte> segment, int realOffset)
Expand All @@ -158,41 +208,32 @@ private bool TrySendSyncNonBlock(ref int sent, in ArraySegment<byte> segment, in
sent += Socket.Send(segment.Array, realOffset, segment.Count, SocketFlags.None);
return true;
}
catch (Exception ex) when(ex is SocketException || ex is ObjectDisposedException)
catch (Exception ex) when (ex is SocketException || ex is ObjectDisposedException)
{
throw new IOException("Send failed", ex);
}
}

#endif

private static void HandleAsyncResult(object sender, SocketAsyncEventArgs e)
internal static void HandleAsyncResult(object sender, SocketAsyncEventArgs e)
{
//Nullify it to prevent double usage during further callbacks
//TaskCompletionSource<int> copy = source;
//source = null;
var source = (TaskCompletionSource<int>)e.UserToken;
var args = (CompletionSourceAsyncEventArgs)e;

if (e.SocketError != SocketError.Success)
if (args.SocketError != SocketError.Success)
{
source.TrySetException(BuildIoException(e.SocketError));
args.SetException(BuildIoException(args.SocketError));
}
else if (e.BytesTransferred == 0)
else if (args.BytesTransferred == 0)
{
// Zero transferred bytes means connection has been closed at the counterpart side.
// See https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socketasynceventargs.bytestransferred?view=netframework-4.7.2
source.TrySetException(BuildConnectionClosedException());
args.SetException(BuildConnectionClosedException());
}
else
{
source.TrySetResult(e.BytesTransferred);
args.SetResult();
}

e.UserToken = null;
e.Completed -= HandleAsyncResult;
e.Dispose();
}

#endif
private static IOException BuildConnectionClosedException() => BuildIoException(SocketError.SocketError, "The counterpart has closed the connection");

private static IOException BuildIoException(SocketError socketError, string message = "Send or receive failed")
Expand Down
38 changes: 38 additions & 0 deletions src/SolarWind/Internals/CompletionSourceAsyncEventArgs.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using System;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;
using Codestellation.SolarWind.Threading;

namespace Codestellation.SolarWind.Internals
{
internal class CompletionSourceAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource
{
private SyncValueTaskSourceCore _valueSource = new SyncValueTaskSourceCore();

public ValueTask Task
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get => new ValueTask(this, _valueSource.Version);
}

public void Reset()
=> _valueSource.Reset();

public void SetException(Exception exception)
=> _valueSource.SetException(exception);

public void SetResult()
=> _valueSource.SetResult();

public ValueTaskSourceStatus GetStatus(short token)
=> _valueSource.GetStatus(token);

public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags)
=> _valueSource.OnCompleted(continuation, state, token, flags);

public void GetResult(short token)
=> _valueSource.GetResult(token);
}
}
10 changes: 6 additions & 4 deletions src/SolarWind/Internals/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ public async ValueTask ReceiveAsync(PooledMemoryStream destination, int bytesToR
{
_readPosition = 0;
_readLength = 0;
var memory = new Memory<byte>(_readBuffer, 0, _readBuffer.Length);
_readLength = await _networkStream
.ReadAsync(_readBuffer, 0, _readBuffer.Length, cancellation)
.ReadAsync(memory, cancellation)
.ConfigureAwait(false);
available = _readLength;
}
Expand Down Expand Up @@ -106,7 +107,7 @@ public async ValueTask WriteAsync(Message message, CancellationToken cancellatio
var readFromPayload = payload.Read(_writeBuffer, _writePosition, sliceSize);

Debug.Assert(sliceSize == readFromPayload);

bytesToSend -= readFromPayload;
_writePosition += readFromPayload;
}
Expand Down Expand Up @@ -228,11 +229,12 @@ private static void ConfigureSocket(Socket socket, SolarWindHubOptions options)
socket.LingerState = new LingerOption(true, 1);
}

public Task FlushAsync(CancellationToken cancellation)
public ValueTask FlushAsync(CancellationToken cancellation)
{
var length = _writePosition;
_writePosition = 0; //Zero it here to avoid making the method async
return _networkStream.WriteAsync(_writeBuffer, 0, length, cancellation);
var memory = new ReadOnlyMemory<byte>(_writeBuffer, 0, length);
return _networkStream.WriteAsync(memory, cancellation);
}

public void Dispose()
Expand Down
Loading

0 comments on commit ffa9e07

Please sign in to comment.