diff --git a/src/SolarWind.Tests/TestContext.cs b/src/SolarWind.Tests/TestContext.cs index c09797e..3d9ed99 100644 --- a/src/SolarWind.Tests/TestContext.cs +++ b/src/SolarWind.Tests/TestContext.cs @@ -1,11 +1,41 @@ +using System; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Console; namespace Codestellation.SolarWind.Tests { public static class TestContext { - public static readonly ILoggerFactory LoggerFactory = - new LoggerFactory(new ILoggerProvider[] {new ConsoleLoggerProvider((s, level) => true, false)}); + public static readonly ILoggerFactory LoggerFactory = new ConsoleLoggerFactory(); + + public class ConsoleLoggerFactory : ILoggerFactory + { + public void Dispose() => throw new NotImplementedException(); + + public ILogger CreateLogger(string categoryName) + => ConsoleLogger.Instance; + + public void AddProvider(ILoggerProvider provider) => throw new NotImplementedException(); + } + + public class ConsoleLogger : ILogger + { + public static readonly ConsoleLogger Instance = new ConsoleLogger(); + + private ConsoleLogger() + { + } + + public void Log( + LogLevel logLevel, + EventId eventId, + TState state, + Exception exception, + Func formatter) + => Console.WriteLine(formatter(state, exception)); + + public bool IsEnabled(LogLevel logLevel) => true; + + public IDisposable BeginScope(TState state) => throw new NotImplementedException(); + } } } \ No newline at end of file diff --git a/src/SolarWind/Internals/AsyncNetworkStream.cs b/src/SolarWind/Internals/AsyncNetworkStream.cs index 4ffd54d..b956fd8 100644 --- a/src/SolarWind/Internals/AsyncNetworkStream.cs +++ b/src/SolarWind/Internals/AsyncNetworkStream.cs @@ -11,15 +11,39 @@ namespace Codestellation.SolarWind.Internals // so currently they are overridden for compatibility reasons public class AsyncNetworkStream : NetworkStream { +#if NETSTANDARD2_0 + /*private CompletionSourceAsyncEventArgs _receiveArgs; + private CompletionSourceAsyncEventArgs _sendArgs;*/ +#endif public Socket UnderlyingSocket => Socket; public AsyncNetworkStream(Socket socket) : base(socket, true) { +#if NETSTANDARD2_0 + /*_receiveArgs = SocketEventArgsPool.Instance.Get(); + _sendArgs = SocketEventArgsPool.Instance.Get();*/ +#endif } #if NETSTANDARD2_0 + protected override void Dispose(bool disposing) + { + /*CompletionSourceAsyncEventArgs receiveArgs = Interlocked.Exchange(ref _receiveArgs, null); + if (receiveArgs != null) + { + SocketEventArgsPool.Instance.Return(receiveArgs); + } + + CompletionSourceAsyncEventArgs sendArgs = Interlocked.Exchange(ref _sendArgs, null); + if (sendArgs != null) + { + SocketEventArgsPool.Instance.Return(sendArgs); + }*/ + } + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellation) - => ReadAsync(new Memory(buffer, offset, count), cancellation).AsTask(); + //ReadAsync(new Memory(buffer, offset, count), cancellation).AsTask(); + => throw new NotSupportedException(); //See comments at the top @@ -29,14 +53,15 @@ public async ValueTask ReadAsync(Memory to, CancellationToken cancell { throw new InvalidOperationException("Non array base memory is supported for .net core 2.1+ only"); } + try { - if (TryReceiveSyncNonBlock(segment, out int received)) + if (!TryReceiveSyncNonBlock(segment, out var received)) { - return received; + received = await ReceiveAsync(segment).ConfigureAwait(false); } - return await ReceiveAsync(segment).ConfigureAwait(false); + return received; } catch (Exception ex) when (ex is SocketException || ex is ObjectDisposedException) { @@ -46,36 +71,46 @@ public async ValueTask ReadAsync(Memory to, CancellationToken cancell private bool TryReceiveSyncNonBlock(in ArraySegment segment, out int received) { - - if (Socket.Available == 0) + var available = Socket.Available; + if (available == 0) { received = 0; return false; } - int bytesToRead = Math.Min(segment.Count, Socket.Available); + var bytesToRead = Math.Min(segment.Count, available); received = Socket.Receive(segment.Array, segment.Offset, bytesToRead, SocketFlags.None); return true; - } - private async Task ReceiveAsync(ArraySegment segment) + private async ValueTask ReceiveAsync(ArraySegment segment) { + /*CompletionSourceAsyncEventArgs receiveArgs = Interlocked.Exchange(ref _receiveArgs, null); + var localArgs = receiveArgs != null; + if (!localArgs) + { + receiveArgs = SocketEventArgsPool.Instance.Get(); + }*/ + CompletionSourceAsyncEventArgs receiveArgs = SocketEventArgsPool.Instance.Get(); + receiveArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); + + if (Socket.ReceiveAsync(receiveArgs)) + { + await receiveArgs.Task.ConfigureAwait(false); + } - var source = new TaskCompletionSource(); - var args = new SocketAsyncEventArgs {UserToken = source}; - args.Completed += HandleAsyncResult; - args.SetBuffer(segment.Array, segment.Offset, segment.Count); + var transferred = receiveArgs.BytesTransferred; - if (Socket.ReceiveAsync(args)) + /*if (localArgs) { - return await source.Task.ConfigureAwait(false); + _receiveArgs = receiveArgs; } + else + { + SocketEventArgsPool.Instance.Return(receiveArgs); + }*/ + SocketEventArgsPool.Instance.Return(receiveArgs); - int transferred = args.BytesTransferred; - args.Completed -= HandleAsyncResult; - args.UserToken = null; - args.Dispose(); // Zero transferred bytes means connection has been closed at the counterpart side. // See https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socketasynceventargs.bytestransferred?view=netframework-4.7.2 @@ -87,8 +122,10 @@ private async Task ReceiveAsync(ArraySegment segment) return transferred; } - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => - WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + //WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); + => throw new NotSupportedException(); + //See comments at the top public async ValueTask WriteAsync(ReadOnlyMemory from, CancellationToken cancellationToken) @@ -100,11 +137,11 @@ public async ValueTask WriteAsync(ReadOnlyMemory from, CancellationToken c try { - int left = from.Length; + var left = from.Length; var sent = 0; while (left != 0) { - int realOffset = segment.Offset + sent; + var realOffset = segment.Offset + sent; if (!TrySendSyncNonBlock(ref sent, in segment, realOffset)) { @@ -120,30 +157,43 @@ public async ValueTask WriteAsync(ReadOnlyMemory from, CancellationToken c } } - private async Task SendAsync(ArraySegment segment, int realOffset, int left) + private async ValueTask SendAsync(ArraySegment segment, int realOffset, int left) { - var source = new TaskCompletionSource(); - var args = new SocketAsyncEventArgs {UserToken = source}; - args.Completed += HandleAsyncResult; + /*CompletionSourceAsyncEventArgs sendArgs = Interlocked.Exchange(ref _sendArgs, null); + var localArgs = sendArgs != null; + if (!localArgs) + { + sendArgs = SocketEventArgsPool.Instance.Get(); + }*/ + CompletionSourceAsyncEventArgs sendArgs = SocketEventArgsPool.Instance.Get(); + sendArgs.SetBuffer(segment.Array, realOffset, left); - args.SetBuffer(segment.Array, realOffset, left); + if (Socket.SendAsync(sendArgs)) + { + await sendArgs.Task.ConfigureAwait(false); + } + + //Operation has completed synchronously + var bytesTransferred = sendArgs.BytesTransferred; + SocketError socketError = sendArgs.SocketError; - if (Socket.SendAsync(args)) + /*if (localArgs) { - return await source.Task.ConfigureAwait(false); + _sendArgs = sendArgs; } + else + { + SocketEventArgsPool.Instance.Return(sendArgs); + }*/ - args.Completed -= HandleAsyncResult; - args.UserToken = null; - args.Dispose(); + SocketEventArgsPool.Instance.Return(sendArgs); - //Operation has completed synchronously - if (args.SocketError == SocketError.Success) + if (socketError == SocketError.Success) { - return args.BytesTransferred; + return bytesTransferred; } - throw BuildIoException(args.SocketError); + throw BuildIoException(socketError); } private bool TrySendSyncNonBlock(ref int sent, in ArraySegment segment, int realOffset) @@ -158,41 +208,32 @@ private bool TrySendSyncNonBlock(ref int sent, in ArraySegment segment, in sent += Socket.Send(segment.Array, realOffset, segment.Count, SocketFlags.None); return true; } - catch (Exception ex) when(ex is SocketException || ex is ObjectDisposedException) + catch (Exception ex) when (ex is SocketException || ex is ObjectDisposedException) { throw new IOException("Send failed", ex); } } -#endif - - private static void HandleAsyncResult(object sender, SocketAsyncEventArgs e) + internal static void HandleAsyncResult(object sender, SocketAsyncEventArgs e) { - //Nullify it to prevent double usage during further callbacks - //TaskCompletionSource copy = source; - //source = null; - var source = (TaskCompletionSource)e.UserToken; + var args = (CompletionSourceAsyncEventArgs)e; - if (e.SocketError != SocketError.Success) + if (args.SocketError != SocketError.Success) { - source.TrySetException(BuildIoException(e.SocketError)); + args.SetException(BuildIoException(args.SocketError)); } - else if (e.BytesTransferred == 0) + else if (args.BytesTransferred == 0) { // Zero transferred bytes means connection has been closed at the counterpart side. // See https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socketasynceventargs.bytestransferred?view=netframework-4.7.2 - source.TrySetException(BuildConnectionClosedException()); + args.SetException(BuildConnectionClosedException()); } else { - source.TrySetResult(e.BytesTransferred); + args.SetResult(); } - - e.UserToken = null; - e.Completed -= HandleAsyncResult; - e.Dispose(); } - +#endif private static IOException BuildConnectionClosedException() => BuildIoException(SocketError.SocketError, "The counterpart has closed the connection"); private static IOException BuildIoException(SocketError socketError, string message = "Send or receive failed") diff --git a/src/SolarWind/Internals/CompletionSourceAsyncEventArgs.cs b/src/SolarWind/Internals/CompletionSourceAsyncEventArgs.cs new file mode 100644 index 0000000..3fcbbcd --- /dev/null +++ b/src/SolarWind/Internals/CompletionSourceAsyncEventArgs.cs @@ -0,0 +1,38 @@ +using System; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; +using System.Threading.Tasks.Sources; +using Codestellation.SolarWind.Threading; + +namespace Codestellation.SolarWind.Internals +{ + internal class CompletionSourceAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource + { + private SyncValueTaskSourceCore _valueSource = new SyncValueTaskSourceCore(); + + public ValueTask Task + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => new ValueTask(this, _valueSource.Version); + } + + public void Reset() + => _valueSource.Reset(); + + public void SetException(Exception exception) + => _valueSource.SetException(exception); + + public void SetResult() + => _valueSource.SetResult(); + + public ValueTaskSourceStatus GetStatus(short token) + => _valueSource.GetStatus(token); + + public void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) + => _valueSource.OnCompleted(continuation, state, token, flags); + + public void GetResult(short token) + => _valueSource.GetResult(token); + } +} \ No newline at end of file diff --git a/src/SolarWind/Internals/Connection.cs b/src/SolarWind/Internals/Connection.cs index 239dfc0..ecb04a5 100644 --- a/src/SolarWind/Internals/Connection.cs +++ b/src/SolarWind/Internals/Connection.cs @@ -55,8 +55,9 @@ public async ValueTask ReceiveAsync(PooledMemoryStream destination, int bytesToR { _readPosition = 0; _readLength = 0; + var memory = new Memory(_readBuffer, 0, _readBuffer.Length); _readLength = await _networkStream - .ReadAsync(_readBuffer, 0, _readBuffer.Length, cancellation) + .ReadAsync(memory, cancellation) .ConfigureAwait(false); available = _readLength; } @@ -106,7 +107,7 @@ public async ValueTask WriteAsync(Message message, CancellationToken cancellatio var readFromPayload = payload.Read(_writeBuffer, _writePosition, sliceSize); Debug.Assert(sliceSize == readFromPayload); - + bytesToSend -= readFromPayload; _writePosition += readFromPayload; } @@ -228,11 +229,12 @@ private static void ConfigureSocket(Socket socket, SolarWindHubOptions options) socket.LingerState = new LingerOption(true, 1); } - public Task FlushAsync(CancellationToken cancellation) + public ValueTask FlushAsync(CancellationToken cancellation) { var length = _writePosition; _writePosition = 0; //Zero it here to avoid making the method async - return _networkStream.WriteAsync(_writeBuffer, 0, length, cancellation); + var memory = new ReadOnlyMemory(_writeBuffer, 0, length); + return _networkStream.WriteAsync(memory, cancellation); } public void Dispose() diff --git a/src/SolarWind/Internals/LocalSocketEventArgsPool.cs b/src/SolarWind/Internals/LocalSocketEventArgsPool.cs new file mode 100644 index 0000000..8b211f4 --- /dev/null +++ b/src/SolarWind/Internals/LocalSocketEventArgsPool.cs @@ -0,0 +1,40 @@ +#if NETSTANDARD2_0 +using Microsoft.Extensions.ObjectPool; + +namespace Codestellation.SolarWind.Internals +{ + internal class SocketEventArgsPool : DefaultObjectPool + { + public static readonly SocketEventArgsPool Instance = new SocketEventArgsPool(); + + private class Policy : IPooledObjectPolicy + { + public CompletionSourceAsyncEventArgs Create() + { + var result = new CompletionSourceAsyncEventArgs(); + result.Completed += AsyncNetworkStream.HandleAsyncResult; + return result; + } + + public bool Return(CompletionSourceAsyncEventArgs obj) => true; + } + + public SocketEventArgsPool() + : base(new Policy()) + { + } + + public SocketEventArgsPool(int maximumRetained) + : base(new Policy(), maximumRetained) + { + } + + public override CompletionSourceAsyncEventArgs Get() + { + CompletionSourceAsyncEventArgs result = base.Get(); + result.Reset(); + return result; + } + } +} +#endif \ No newline at end of file diff --git a/src/SolarWind/Internals/PooledMemoryStream.cs b/src/SolarWind/Internals/PooledMemoryStream.cs index 31dfe03..ae7615d 100644 --- a/src/SolarWind/Internals/PooledMemoryStream.cs +++ b/src/SolarWind/Internals/PooledMemoryStream.cs @@ -229,7 +229,7 @@ public int Write(Stream from, int count) return count - left; } - public async ValueTask WriteAsync(Stream from, int count, CancellationToken cancellation) + public async ValueTask WriteAsync(AsyncNetworkStream from, int count, CancellationToken cancellation) { if (count == 0) { @@ -243,7 +243,8 @@ public async ValueTask WriteAsync(Stream from, int count, CancellationToken { MemoryMarshal.TryGetArray(GetWritableMemory(left), out ArraySegment segment); int bytesToRead = Math.Min(count, segment.Count); - lastRead = await from.ReadAsync(segment.Array, segment.Offset, bytesToRead, cancellation).ConfigureAwait(false); + var memory = new Memory(segment.Array, segment.Offset, bytesToRead); + lastRead = await from.ReadAsync(memory, cancellation).ConfigureAwait(false); left -= lastRead; _position += lastRead; } while (left != 0 && lastRead > 0); @@ -255,42 +256,5 @@ public async ValueTask WriteAsync(Stream from, int count, CancellationToken return count - left; } - - public ValueTask CopyIntoAsync(Stream destination, CancellationToken cancellation) - { - CopyInto(destination); - return new ValueTask(Task.CompletedTask); - } - - public async ValueTask CopyIntoAsync(AsyncNetworkStream destination, CancellationToken cancellation) - { - var left = (int)_length; - foreach (byte[] buffer in _buffers) - { - int bytesToCopy = Math.Min(left, buffer.Length); - var memory = new ReadOnlyMemory(buffer, 0, bytesToCopy); - await destination.WriteAsync(memory, cancellation).ConfigureAwait(ContinueOn.IOScheduler); - left -= bytesToCopy; - if (left == 0) - { - break; - } - } - } - - public void CopyInto(Stream destination) - { - var left = (int)_length; - foreach (byte[] buffer in _buffers) - { - int bytesToCopy = Math.Min(left, buffer.Length); - destination.Write(buffer, 0, bytesToCopy); - left -= bytesToCopy; - if (left == 0) - { - break; - } - } - } } } \ No newline at end of file diff --git a/src/SolarWind/SolarWind.csproj b/src/SolarWind/SolarWind.csproj index 48eff0e..14cff00 100644 --- a/src/SolarWind/SolarWind.csproj +++ b/src/SolarWind/SolarWind.csproj @@ -1,4 +1,4 @@ - + netstandard2.0;netcoreapp2.1 @@ -13,7 +13,7 @@ true Codestellation.SolarWind Codestellation.SolarWind - 7035 + 7035;1591 true $(AllowedOutputExtensionsInPackageBuildOutputFolder);.pdb latest @@ -27,7 +27,7 @@ - + diff --git a/src/SolarWind/Threading/ContinueOn.cs b/src/SolarWind/Threading/ContinueOn.cs index 106fd97..d1656cd 100644 --- a/src/SolarWind/Threading/ContinueOn.cs +++ b/src/SolarWind/Threading/ContinueOn.cs @@ -1,5 +1,3 @@ -using System.Threading.Tasks; - namespace Codestellation.SolarWind.Threading { internal static class ContinueOn @@ -7,7 +5,7 @@ internal static class ContinueOn // Actually is not true. Continuation will run on a captured context. // If the context has TaskScheduler.Current = IOScheduler it will run on it // Will skip other cases to avoid dangerous or redundant captures. - public static bool IOScheduler => TaskScheduler.Current == IOTaskScheduler.Instance; + public const bool IOScheduler = false; //TaskScheduler.Current == IOTaskScheduler.Instance; public const bool DefaultScheduler = false; } diff --git a/src/SolarWind/Threading/IOTaskScheduler.cs b/src/SolarWind/Threading/IOTaskScheduler.cs index 8b70c73..a0f32bc 100644 --- a/src/SolarWind/Threading/IOTaskScheduler.cs +++ b/src/SolarWind/Threading/IOTaskScheduler.cs @@ -1,4 +1,4 @@ - + using System; using System.Collections.Generic; using System.Linq; @@ -14,7 +14,8 @@ public sealed unsafe class IOTaskScheduler : TaskScheduler private readonly ObjectPool _workItemsPool; - private IOTaskScheduler() => _workItemsPool = new DefaultObjectPool(new WorkItemPolicy(this)); + private IOTaskScheduler() + => _workItemsPool = new DefaultObjectPool(new WorkItemPolicy(this)); protected override void QueueTask(Task task) { @@ -23,9 +24,11 @@ protected override void QueueTask(Task task) ThreadPool.UnsafeQueueNativeOverlapped(wi.PNOlap); } - protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) => TryExecuteTask(task); + protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) + => TryExecuteTask(task); - protected override IEnumerable GetScheduledTasks() => Enumerable.Empty(); + protected override IEnumerable GetScheduledTasks() + => Enumerable.Empty(); private class WorkItem { diff --git a/src/SolarWind/Threading/ManualResetValueTaskSourceCore.cs b/src/SolarWind/Threading/ManualResetValueTaskSourceCore.cs new file mode 100644 index 0000000..e4e2d12 --- /dev/null +++ b/src/SolarWind/Threading/ManualResetValueTaskSourceCore.cs @@ -0,0 +1,294 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics; +using System.Runtime.ExceptionServices; +using System.Runtime.InteropServices; + +//TODO: Do not compile it for .netstandard 2.1 /.netcoreapp 3.0+ +namespace System.Threading.Tasks.Sources +{ + /// Provides the core logic for implementing a manual-reset or . + /// + [StructLayout(LayoutKind.Auto)] + public struct ManualResetValueTaskSourceCore + { + /// + /// The callback to invoke when the operation completes if was called before the operation completed, + /// or if the operation completed before a callback was supplied, + /// or null if a callback hasn't yet been provided and the operation hasn't yet completed. + /// + private Action _continuation; + + /// State to pass to . + private object _continuationState; + + /// to flow to the callback, or null if no flowing is required. + private ExecutionContext _executionContext; + + /// + /// A "captured" or with which to invoke the callback, + /// or null if no special context is required. + /// + private object _capturedContext; + + /// Whether the current operation has completed. + private bool _completed; + + /// The result with which the operation succeeded, or the default value if it hasn't yet completed or failed. + private TResult _result; + + /// The exception with which the operation failed, or null if it hasn't yet completed or completed successfully. + private ExceptionDispatchInfo _error; + + /// The current version of this value, used to help prevent misuse. + private short _version; + + /// Gets or sets whether to force continuations to run asynchronously. + /// Continuations may run asynchronously if this is false, but they'll never run synchronously if this is true. + public bool RunContinuationsAsynchronously { get; set; } + + /// Resets to prepare for the next operation. + public void Reset() + { + // Reset/update state for the next use/await of this instance. + _version++; + _completed = false; + _result = default; + _error = null; + _executionContext = null; + _capturedContext = null; + _continuation = null; + _continuationState = null; + } + + /// Completes with a successful result. + /// The result. + public void SetResult(TResult result) + { + _result = result; + SignalCompletion(); + } + + /// Completes with an error. + /// The exception. + public void SetException(Exception error) + { + _error = ExceptionDispatchInfo.Capture(error); + SignalCompletion(); + } + + /// Gets the operation version. + public short Version => _version; + + /// Gets the status of the operation. + /// Opaque value that was provided to the 's constructor. + public ValueTaskSourceStatus GetStatus(short token) + { + ValidateToken(token); + return + _continuation == null || !_completed ? ValueTaskSourceStatus.Pending : + _error == null ? ValueTaskSourceStatus.Succeeded : + _error.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled : + ValueTaskSourceStatus.Faulted; + } + + /// Gets the result of the operation. + /// Opaque value that was provided to the 's constructor. + public TResult GetResult(short token) + { + ValidateToken(token); + if (!_completed) + { + ValueTaskSourceHelper.ThrowInvalidOperationException("Getting result on non-completed task"); + } + + _error?.Throw(); + return _result; + } + + /// Schedules the continuation action for this operation. + /// The continuation to invoke when the operation has completed. + /// The state object to pass to when it's invoked. + /// Opaque value that was provided to the 's constructor. + /// The flags describing the behavior of the continuation. + public void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) + { + if (continuation == null) + { + throw new ArgumentNullException(nameof(continuation)); + } + + ValidateToken(token); + + if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0) + { + _executionContext = ExecutionContext.Capture(); + } + + if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0) + { + SynchronizationContext sc = SynchronizationContext.Current; + if (sc != null && sc.GetType() != typeof(SynchronizationContext)) + { + _capturedContext = sc; + } + else + { + TaskScheduler ts = TaskScheduler.Current; + if (ts != TaskScheduler.Default) + { + _capturedContext = ts; + } + } + } + + // We need to set the continuation state before we swap in the delegate, so that + // if there's a race between this and SetResult/Exception and SetResult/Exception + // sees the _continuation as non-null, it'll be able to invoke it with the state + // stored here. However, this also means that if this is used incorrectly (e.g. + // awaited twice concurrently), _continuationState might get erroneously overwritten. + // To minimize the chances of that, we check preemptively whether _continuation + // is already set to something other than the completion sentinel. + + object oldContinuation = _continuation; + if (oldContinuation == null) + { + _continuationState = state; + oldContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null); + } + + if (oldContinuation != null) + { + // Operation already completed, so we need to queue the supplied callback. + if (!ReferenceEquals(oldContinuation, ValueTaskSourceHelper.s_sentinel)) + { + ValueTaskSourceHelper.ThrowInvalidOperationException("Something went wrong"); + } + + switch (_capturedContext) + { + case null: + { + WaitCallback callback = s => continuation(s); + if (_executionContext != null) + { + ThreadPool.QueueUserWorkItem(callback, state); + } + else + { + ThreadPool.UnsafeQueueUserWorkItem(callback, state); + } + } + + break; + + case SynchronizationContext sc: + sc.Post(s => + { + var tuple = (Tuple, object>)s!; + tuple.Item1(tuple.Item2); + }, Tuple.Create(continuation, state)); + break; + + case TaskScheduler ts: + Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts); + break; + } + } + } + + /// Ensures that the specified token matches the current version. + /// The token supplied by . + private void ValidateToken(short token) + { + if (token != _version) + { + ValueTaskSourceHelper.ThrowInvalidOperationException("Version mismatch. Possible double awaiting of the value task"); + } + } + + /// Signals that the operation has completed. Invoked after the result or error has been set. + private void SignalCompletion() + { + if (_completed) + { + ValueTaskSourceHelper.ThrowInvalidOperationException("Signaling completion on non-completed task"); + } + + _completed = true; + + if (_continuation != null || Interlocked.CompareExchange(ref _continuation, ValueTaskSourceHelper.s_sentinel, null) != null) + { + if (_executionContext != null) + { + //TODO: Looks like a hack. Consider dropping this path completely. + ExecutionContext.Run( + _executionContext, + s => ((ManualResetValueTaskSourceCore)s).InvokeContinuation(), + _continuationState); + //ExecutionContext.RunInternal( + // _executionContext, + // (ref ManualResetValueTaskSourceCore s) => s.InvokeContinuation(), + // ref this); + } + else + { + InvokeContinuation(); + } + } + } + + /// + /// Invokes the continuation with the appropriate captured context / scheduler. + /// This assumes that if is not null we're already + /// running within that . + /// + private void InvokeContinuation() + { + Debug.Assert(_continuation != null); + + switch (_capturedContext) + { + case null: + { + //TODO: Completely undesired path due to redundant allocations + if (RunContinuationsAsynchronously) + { + Action continuation = _continuation; + WaitCallback callback = s => continuation(s); + if (_executionContext != null) + { + ThreadPool.QueueUserWorkItem(callback, _continuationState); + } + else + { + ThreadPool.UnsafeQueueUserWorkItem(callback, _continuationState); + } + } + else + { + _continuation(_continuationState); + } + } + break; + + case SynchronizationContext sc: + sc.Post(s => + { + var state = (Tuple, object>)s!; + state.Item1(state.Item2); + }, Tuple.Create(_continuation, _continuationState)); + break; + + case TaskScheduler ts: + { + Task.Factory.StartNew(_continuation, _continuationState, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts); + } + break; + } + } + + } +} \ No newline at end of file diff --git a/src/SolarWind/Threading/SyncValueTaskSourceCore.cs b/src/SolarWind/Threading/SyncValueTaskSourceCore.cs new file mode 100644 index 0000000..fe363d6 --- /dev/null +++ b/src/SolarWind/Threading/SyncValueTaskSourceCore.cs @@ -0,0 +1,161 @@ +using System; +using System.Diagnostics; +using System.Runtime.ExceptionServices; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Sources; + +namespace Codestellation.SolarWind.Threading +{ + [StructLayout(LayoutKind.Auto)] + public struct SyncValueTaskSourceCore + { + /// + /// The callback to invoke when the operation completes if was called before the operation completed, + /// or if the operation completed before a callback was supplied, + /// or null if a callback hasn't yet been provided and the operation hasn't yet completed. + /// + private Action _continuation; + + /// State to pass to . + private object _continuationState; + + + /// Whether the current operation has completed. + private bool _completed; + + /// The exception with which the operation failed, or null if it hasn't yet completed or completed successfully. + private ExceptionDispatchInfo _error; + + /// The current version of this value, used to help prevent misuse. + private short _version; + + + /// Resets to prepare for the next operation. + public void Reset() + { + // Reset/update state for the next use/await of this instance. + _version++; + _completed = false; + _error = null; + _continuation = null; + _continuationState = null; + } + + /// Completes with a successful result. + public void SetResult() => SignalCompletion(); + + /// Completes with an error. + /// The exception. + public void SetException(Exception error) + { + _error = ExceptionDispatchInfo.Capture(error); + SignalCompletion(); + } + + /// Gets the operation version. + public short Version => _version; + + /// Gets the status of the operation. + /// Opaque value that was provided to the 's constructor. + public ValueTaskSourceStatus GetStatus(short token) + { + ValidateToken(token); + return + _continuation == null || !_completed ? ValueTaskSourceStatus.Pending : + _error == null ? ValueTaskSourceStatus.Succeeded : + _error.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled : + ValueTaskSourceStatus.Faulted; + } + + /// Gets the result of the operation. + /// Opaque value that was provided to the 's constructor. + public void GetResult(short token) + { + ValidateToken(token); + if (!_completed) + { + ValueTaskSourceHelper.ThrowInvalidOperationException("Getting result on non-completed task"); + } + + _error?.Throw(); + } + + /// Schedules the continuation action for this operation. + /// The continuation to invoke when the operation has completed. + /// The state object to pass to when it's invoked. + /// Opaque value that was provided to the 's constructor. + /// The flags describing the behavior of the continuation. + public void OnCompleted(Action continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) + { + if (continuation == null) + { + throw new ArgumentNullException(nameof(continuation)); + } + + ValidateToken(token); + + // We need to set the continuation state before we swap in the delegate, so that + // if there's a race between this and SetResult/Exception and SetResult/Exception + // sees the _continuation as non-null, it'll be able to invoke it with the state + // stored here. However, this also means that if this is used incorrectly (e.g. + // awaited twice concurrently), _continuationState might get erroneously overwritten. + // To minimize the chances of that, we check preemptively whether _continuation + // is already set to something other than the completion sentinel. + + object oldContinuation = _continuation; + if (oldContinuation == null) + { + _continuationState = state; + oldContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null); + } + + if (oldContinuation != null) + { + // Operation already completed, so we need to queue the supplied callback. + if (!ReferenceEquals(oldContinuation, ValueTaskSourceHelper.s_sentinel)) + { + ValueTaskSourceHelper.ThrowInvalidOperationException("Something went wrong"); + } + + ThreadPool.UnsafeQueueUserWorkItem(s => continuation(s), state); + } + } + + /// Ensures that the specified token matches the current version. + /// The token supplied by . + private void ValidateToken(short token) + { + if (token != _version) + { + ValueTaskSourceHelper.ThrowInvalidOperationException("Version mismatch. Possible double awaiting of the value task"); + } + } + + /// Signals that the operation has completed. Invoked after the result or error has been set. + private void SignalCompletion() + { + if (_completed) + { + ValueTaskSourceHelper.ThrowInvalidOperationException("Signaling completion on non-completed task"); + } + + _completed = true; + + if (_continuation != null || Interlocked.CompareExchange(ref _continuation, ValueTaskSourceHelper.s_sentinel, null) != null) + { + InvokeContinuation(); + } + } + + /// + /// Invokes the continuation synchronously. + /// + private void InvokeContinuation() + { + Debug.Assert(_continuation != null); + _continuation(_continuationState); + } + } +} \ No newline at end of file diff --git a/src/SolarWind/Threading/ValueTaskSourceHelper.cs b/src/SolarWind/Threading/ValueTaskSourceHelper.cs new file mode 100644 index 0000000..ecd05a2 --- /dev/null +++ b/src/SolarWind/Threading/ValueTaskSourceHelper.cs @@ -0,0 +1,18 @@ +using System.Diagnostics; + +namespace System.Threading.Tasks.Sources +{ + internal static class ValueTaskSourceHelper // separated out of generic to avoid unnecessary duplication + { + internal static readonly Action s_sentinel = CompletionSentinel; + + private static void CompletionSentinel(object _) // named method to aid debugging + { + const string message = "The sentinel delegate should never be invoked."; + Debug.Fail(message); + ThrowInvalidOperationException(message); + } + + internal static void ThrowInvalidOperationException(string message) => throw new InvalidOperationException(message); + } +} \ No newline at end of file