Skip to content

Commit

Permalink
Use value tasks for async network stream (closes #4)
Browse files Browse the repository at this point in the history
  • Loading branch information
solyutor committed May 9, 2020
1 parent e931a11 commit 691fab9
Show file tree
Hide file tree
Showing 12 changed files with 638 additions and 118 deletions.
36 changes: 33 additions & 3 deletions src/SolarWind.Tests/TestContext.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,41 @@
using System;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Console;

namespace Codestellation.SolarWind.Tests
{
public static class TestContext
{
public static readonly ILoggerFactory LoggerFactory =
new LoggerFactory(new ILoggerProvider[] {new ConsoleLoggerProvider((s, level) => true, false)});
public static readonly ILoggerFactory LoggerFactory = new ConsoleLoggerFactory();

public class ConsoleLoggerFactory : ILoggerFactory
{
public void Dispose() => throw new NotImplementedException();

public ILogger CreateLogger(string categoryName)
=> ConsoleLogger.Instance;

public void AddProvider(ILoggerProvider provider) => throw new NotImplementedException();
}

public class ConsoleLogger : ILogger
{
public static readonly ConsoleLogger Instance = new ConsoleLogger();

private ConsoleLogger()
{
}

public void Log<TState>(
LogLevel logLevel,
EventId eventId,
TState state,
Exception exception,
Func<TState, Exception, string> formatter)
=> Console.WriteLine(formatter(state, exception));

public bool IsEnabled(LogLevel logLevel) => true;

public IDisposable BeginScope<TState>(TState state) => throw new NotImplementedException();
}
}
}
93 changes: 46 additions & 47 deletions src/SolarWind/Internals/AsyncNetworkStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ public AsyncNetworkStream(Socket socket) : base(socket, true)
}

#if NETSTANDARD2_0
//protected override void Dispose(bool disposing) => Socket.SafeDispose();

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellation)
=> ReadAsync(new Memory<byte>(buffer, offset, count), cancellation).AsTask();
//ReadAsync(new Memory<byte>(buffer, offset, count), cancellation).AsTask();
=> throw new NotSupportedException();

//See comments at the top

Expand All @@ -29,14 +32,15 @@ public async ValueTask<int> ReadAsync(Memory<byte> to, CancellationToken cancell
{
throw new InvalidOperationException("Non array base memory is supported for .net core 2.1+ only");
}

try
{
if (TryReceiveSyncNonBlock(segment, out int received))
if (!TryReceiveSyncNonBlock(segment, out int received))
{
return received;
received = await ReceiveAsync(segment).ConfigureAwait(false);
}

return await ReceiveAsync(segment).ConfigureAwait(false);
return received;
}
catch (Exception ex) when (ex is SocketException || ex is ObjectDisposedException)
{
Expand All @@ -46,34 +50,32 @@ public async ValueTask<int> ReadAsync(Memory<byte> to, CancellationToken cancell

private bool TryReceiveSyncNonBlock(in ArraySegment<byte> segment, out int received)
{

if (Socket.Available == 0)
int available = Socket.Available;
if (available == 0)
{
received = 0;
return false;
}

int bytesToRead = Math.Min(segment.Count, Socket.Available);
int bytesToRead = Math.Min(segment.Count, available);
received = Socket.Receive(segment.Array, segment.Offset, bytesToRead, SocketFlags.None);
return true;

}

private async Task<int> ReceiveAsync(ArraySegment<byte> segment)
private async ValueTask<int> ReceiveAsync(ArraySegment<byte> segment)
{
var args = new CompletionSourceAsyncEventArgs();
args.Completed += HandleAsyncResult;
args.SetBuffer(segment.Array, segment.Offset, segment.Count);
CompletionSourceAsyncEventArgs receiveArgs = SocketEventArgsPool.Instance.Get();
receiveArgs.SetBuffer(segment.Array, segment.Offset, segment.Count);

if (Socket.ReceiveAsync(args))
if (Socket.ReceiveAsync(receiveArgs))
{
return await args.CompletionSource.Task.ConfigureAwait(false);
await receiveArgs.Task.ConfigureAwait(false);
}

//UnusedCompletionSources.Push(source);
int transferred = args.BytesTransferred;
args.Completed -= HandleAsyncResult;
args.Dispose();
int transferred = receiveArgs.BytesTransferred;

SocketEventArgsPool.Instance.Return(receiveArgs);


// Zero transferred bytes means connection has been closed at the counterpart side.
// See https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socketasynceventargs.bytestransferred?view=netframework-4.7.2
Expand All @@ -85,8 +87,10 @@ private async Task<int> ReceiveAsync(ArraySegment<byte> segment)
return transferred;
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
//WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
=> throw new NotSupportedException();


//See comments at the top
public async ValueTask WriteAsync(ReadOnlyMemory<byte> from, CancellationToken cancellationToken)
Expand Down Expand Up @@ -118,28 +122,28 @@ public async ValueTask WriteAsync(ReadOnlyMemory<byte> from, CancellationToken c
}
}

private async Task<int> SendAsync(ArraySegment<byte> segment, int realOffset, int left)
private async ValueTask<int> SendAsync(ArraySegment<byte> segment, int realOffset, int left)
{
var args = new CompletionSourceAsyncEventArgs();
args.Completed += HandleAsyncResult;
CompletionSourceAsyncEventArgs sendArgs = SocketEventArgsPool.Instance.Get();
sendArgs.SetBuffer(segment.Array, realOffset, left);

args.SetBuffer(segment.Array, realOffset, left);

if (Socket.SendAsync(args))
if (Socket.SendAsync(sendArgs))
{
return await args.CompletionSource.Task.ConfigureAwait(false);
await sendArgs.Task.ConfigureAwait(false);
}

args.Completed -= HandleAsyncResult;
args.Dispose();

//Operation has completed synchronously
if (args.SocketError == SocketError.Success)
int bytesTransferred = sendArgs.BytesTransferred;
SocketError socketError = sendArgs.SocketError;

SocketEventArgsPool.Instance.Return(sendArgs);

if (socketError == SocketError.Success)
{
return args.BytesTransferred;
return bytesTransferred;
}

throw BuildIoException(args.SocketError);
throw BuildIoException(socketError);
}

private bool TrySendSyncNonBlock(ref int sent, in ArraySegment<byte> segment, int realOffset)
Expand All @@ -154,37 +158,32 @@ private bool TrySendSyncNonBlock(ref int sent, in ArraySegment<byte> segment, in
sent += Socket.Send(segment.Array, realOffset, segment.Count, SocketFlags.None);
return true;
}
catch (Exception ex) when(ex is SocketException || ex is ObjectDisposedException)
catch (Exception ex) when (ex is SocketException || ex is ObjectDisposedException)
{
throw new IOException("Send failed", ex);
}
}

#endif

private static void HandleAsyncResult(object sender, SocketAsyncEventArgs e)
internal static void HandleAsyncResult(object sender, SocketAsyncEventArgs e)
{
TaskCompletionSource<int> source = ((CompletionSourceAsyncEventArgs)e).CompletionSource;
var args = (CompletionSourceAsyncEventArgs)e;

if (e.SocketError != SocketError.Success)
if (args.SocketError != SocketError.Success)
{
source.TrySetException(BuildIoException(e.SocketError));
args.SetException(BuildIoException(args.SocketError));
}
else if (e.BytesTransferred == 0)
else if (args.BytesTransferred == 0)
{
// Zero transferred bytes means connection has been closed at the counterpart side.
// See https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socketasynceventargs.bytestransferred?view=netframework-4.7.2
source.TrySetException(BuildConnectionClosedException());
args.SetException(BuildConnectionClosedException());
}
else
{
source.TrySetResult(e.BytesTransferred);
args.SetResult();
}

e.Completed -= HandleAsyncResult;
e.Dispose();
}

#endif
private static IOException BuildConnectionClosedException() => BuildIoException(SocketError.SocketError, "The counterpart has closed the connection");

private static IOException BuildIoException(SocketError socketError, string message = "Send or receive failed")
Expand Down
41 changes: 26 additions & 15 deletions src/SolarWind/Internals/CompletionSourceAsyncEventArgs.cs
Original file line number Diff line number Diff line change
@@ -1,27 +1,38 @@
using System;
using System.Net.Sockets;
using System.Threading;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;
using Codestellation.SolarWind.Threading;

namespace Codestellation.SolarWind.Internals
{
internal class CompletionSourceAsyncEventArgs : SocketAsyncEventArgs
internal class CompletionSourceAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource
{
private TaskCompletionSource<int> _source;
private SyncValueTaskSourceCore _valueSource = new SyncValueTaskSourceCore();

public TaskCompletionSource<int> CompletionSource
public ValueTask Task
{
get
{
if (_source != null)
{
return _source;
}
// Here's possible multiple creation of TaskCompletionSource, but it's unlikely to happen;
// However it allows thread-safe assigning without locking.
Interlocked.CompareExchange(ref _source, new TaskCompletionSource<int>(), null);
return _source;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get => new ValueTask(this, _valueSource.Version);
}

public void Reset()
=> _valueSource.Reset();

public void SetException(Exception exception)
=> _valueSource.SetException(exception);

public void SetResult()
=> _valueSource.SetResult();

public ValueTaskSourceStatus GetStatus(short token)
=> _valueSource.GetStatus(token);

public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags)
=> _valueSource.OnCompleted(continuation, state, token, flags);

public void GetResult(short token)
=> _valueSource.GetResult(token);
}
}
10 changes: 6 additions & 4 deletions src/SolarWind/Internals/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ public async ValueTask ReceiveAsync(PooledMemoryStream destination, int bytesToR
{
_readPosition = 0;
_readLength = 0;
var memory = new Memory<byte>(_readBuffer, 0, _readBuffer.Length);
_readLength = await _networkStream
.ReadAsync(_readBuffer, 0, _readBuffer.Length, cancellation)
.ReadAsync(memory, cancellation)
.ConfigureAwait(false);
available = _readLength;
}
Expand Down Expand Up @@ -106,7 +107,7 @@ public async ValueTask WriteAsync(Message message, CancellationToken cancellatio
var readFromPayload = payload.Read(_writeBuffer, _writePosition, sliceSize);

Debug.Assert(sliceSize == readFromPayload);

bytesToSend -= readFromPayload;
_writePosition += readFromPayload;
}
Expand Down Expand Up @@ -228,11 +229,12 @@ private static void ConfigureSocket(Socket socket, SolarWindHubOptions options)
socket.LingerState = new LingerOption(true, 1);
}

public Task FlushAsync(CancellationToken cancellation)
public ValueTask FlushAsync(CancellationToken cancellation)
{
var length = _writePosition;
_writePosition = 0; //Zero it here to avoid making the method async
return _networkStream.WriteAsync(_writeBuffer, 0, length, cancellation);
var memory = new ReadOnlyMemory<byte>(_writeBuffer, 0, length);
return _networkStream.WriteAsync(memory, cancellation);
}

public void Dispose()
Expand Down
40 changes: 40 additions & 0 deletions src/SolarWind/Internals/LocalSocketEventArgsPool.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#if NETSTANDARD2_0
using Microsoft.Extensions.ObjectPool;

namespace Codestellation.SolarWind.Internals
{
internal class SocketEventArgsPool : DefaultObjectPool<CompletionSourceAsyncEventArgs>
{
public static readonly SocketEventArgsPool Instance = new SocketEventArgsPool();

private class Policy : IPooledObjectPolicy<CompletionSourceAsyncEventArgs>
{
public CompletionSourceAsyncEventArgs Create()
{
var result = new CompletionSourceAsyncEventArgs();
result.Completed += AsyncNetworkStream.HandleAsyncResult;
return result;
}

public bool Return(CompletionSourceAsyncEventArgs obj) => true;
}

public SocketEventArgsPool()
: base(new Policy())
{
}

public SocketEventArgsPool(int maximumRetained)
: base(new Policy(), maximumRetained)
{
}

public override CompletionSourceAsyncEventArgs Get()
{
CompletionSourceAsyncEventArgs result = base.Get();
result.Reset();
return result;
}
}
}
#endif
Loading

0 comments on commit 691fab9

Please sign in to comment.