Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ public static partial class AsyncHelpers
private struct RuntimeAsyncAwaitState
{
public Continuation? SentinelContinuation;
public ICriticalNotifyCompletion? CriticalNotifier;
public INotifyCompletion? Notifier;
public Task? CalledTask;
}

[ThreadStatic]
Expand Down Expand Up @@ -203,7 +205,21 @@ private static unsafe object AllocContinuationResultBox(void* ptr)
return RuntimeTypeHandle.InternalAllocNoChecks((MethodTable*)pMT);
}

private interface IThunkTaskOps<T>
[BypassReadyToRun]
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)]
[RequiresPreviewFeatures]
private static void TransparentAwaitTask(Task t)
{
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
Continuation? sentinelContinuation = state.SentinelContinuation;
if (sentinelContinuation == null)
state.SentinelContinuation = sentinelContinuation = new Continuation();

state.CalledTask = t;
AsyncSuspend(sentinelContinuation);
}

private interface IRuntimeAsyncTaskOps<T>
{
static abstract Action GetContinuationAction(T task);
static abstract Continuation GetContinuationState(T task);
Expand All @@ -212,9 +228,12 @@ private interface IThunkTaskOps<T>
static abstract void PostToSyncContext(T task, SynchronizationContext syncCtx);
}

private sealed class ThunkTask<T> : Task<T>
/// <summary>
/// Represents a wrapped runtime async operation.
/// </summary>
private sealed class RuntimeAsyncTask<T> : Task<T>, ITaskCompletionAction
{
public ThunkTask()
public RuntimeAsyncTask()
{
// We use the base Task's state object field to store the Continuation while posting the task around.
// Ensure that state object isn't published out for others to see.
Expand All @@ -231,31 +250,38 @@ internal override void ExecuteFromThreadPool(Thread threadPoolThread)

private void MoveNext()
{
ThunkTaskCore.MoveNext<ThunkTask<T>, Ops>(this);
RuntimeAsyncTaskCore.DispatchContinuations<RuntimeAsyncTask<T>, Ops>(this);
}

public void HandleSuspended()
{
ThunkTaskCore.HandleSuspended<ThunkTask<T>, Ops>(this);
RuntimeAsyncTaskCore.HandleSuspended<RuntimeAsyncTask<T>, Ops>(this);
}

void ITaskCompletionAction.Invoke(Task completingTask)
{
MoveNext();
}

bool ITaskCompletionAction.InvokeMayRunArbitraryCode => true;

private static readonly SendOrPostCallback s_postCallback = static state =>
{
Debug.Assert(state is ThunkTask<T>);
((ThunkTask<T>)state).MoveNext();
Debug.Assert(state is RuntimeAsyncTask<T>);
((RuntimeAsyncTask<T>)state).MoveNext();
};

private struct Ops : IThunkTaskOps<ThunkTask<T>>
private struct Ops : IRuntimeAsyncTaskOps<RuntimeAsyncTask<T>>
{
public static Action GetContinuationAction(ThunkTask<T> task) => (Action)task.m_action!;
public static void MoveNext(ThunkTask<T> task) => task.MoveNext();
public static Continuation GetContinuationState(ThunkTask<T> task) => (Continuation)task.m_stateObject!;
public static void SetContinuationState(ThunkTask<T> task, Continuation value)
public static Action GetContinuationAction(RuntimeAsyncTask<T> task) => (Action)task.m_action!;
public static void MoveNext(RuntimeAsyncTask<T> task) => task.MoveNext();
public static Continuation GetContinuationState(RuntimeAsyncTask<T> task) => (Continuation)task.m_stateObject!;
public static void SetContinuationState(RuntimeAsyncTask<T> task, Continuation value)
{
task.m_stateObject = value;
}

public static bool SetCompleted(ThunkTask<T> task, Continuation continuation)
public static bool SetCompleted(RuntimeAsyncTask<T> task, Continuation continuation)
{
T result;
if (RuntimeHelpers.IsReferenceOrContainsReferences<T>())
Expand All @@ -277,16 +303,19 @@ public static bool SetCompleted(ThunkTask<T> task, Continuation continuation)
return task.TrySetResult(result);
}

public static void PostToSyncContext(ThunkTask<T> task, SynchronizationContext syncContext)
public static void PostToSyncContext(RuntimeAsyncTask<T> task, SynchronizationContext syncContext)
{
syncContext.Post(s_postCallback, task);
}
}
}

private sealed class ThunkTask : Task
/// <summary>
/// Represents a wrapped runtime async operation.
/// </summary>
private sealed class RuntimeAsyncTask : Task, ITaskCompletionAction
{
public ThunkTask()
public RuntimeAsyncTask()
{
// We use the base Task's state object field to store the Continuation while posting the task around.
// Ensure that state object isn't published out for others to see.
Expand All @@ -303,45 +332,52 @@ internal override void ExecuteFromThreadPool(Thread threadPoolThread)

private void MoveNext()
{
ThunkTaskCore.MoveNext<ThunkTask, Ops>(this);
RuntimeAsyncTaskCore.DispatchContinuations<RuntimeAsyncTask, Ops>(this);
}

public void HandleSuspended()
{
ThunkTaskCore.HandleSuspended<ThunkTask, Ops>(this);
RuntimeAsyncTaskCore.HandleSuspended<RuntimeAsyncTask, Ops>(this);
}

void ITaskCompletionAction.Invoke(Task completingTask)
{
MoveNext();
}

bool ITaskCompletionAction.InvokeMayRunArbitraryCode => true;

private static readonly SendOrPostCallback s_postCallback = static state =>
{
Debug.Assert(state is ThunkTask);
((ThunkTask)state).MoveNext();
Debug.Assert(state is RuntimeAsyncTask);
((RuntimeAsyncTask)state).MoveNext();
};

private struct Ops : IThunkTaskOps<ThunkTask>
private struct Ops : IRuntimeAsyncTaskOps<RuntimeAsyncTask>
{
public static Action GetContinuationAction(ThunkTask task) => (Action)task.m_action!;
public static void MoveNext(ThunkTask task) => task.MoveNext();
public static Continuation GetContinuationState(ThunkTask task) => (Continuation)task.m_stateObject!;
public static void SetContinuationState(ThunkTask task, Continuation value)
public static Action GetContinuationAction(RuntimeAsyncTask task) => (Action)task.m_action!;
public static void MoveNext(RuntimeAsyncTask task) => task.MoveNext();
public static Continuation GetContinuationState(RuntimeAsyncTask task) => (Continuation)task.m_stateObject!;
public static void SetContinuationState(RuntimeAsyncTask task, Continuation value)
{
task.m_stateObject = value;
}

public static bool SetCompleted(ThunkTask task, Continuation continuation)
public static bool SetCompleted(RuntimeAsyncTask task, Continuation continuation)
{
return task.TrySetResult();
}

public static void PostToSyncContext(ThunkTask task, SynchronizationContext syncContext)
public static void PostToSyncContext(RuntimeAsyncTask task, SynchronizationContext syncContext)
{
syncContext.Post(s_postCallback, task);
}
}
}

private static class ThunkTaskCore
private static class RuntimeAsyncTaskCore
{
public static unsafe void MoveNext<T, TOps>(T task) where T : Task where TOps : IThunkTaskOps<T>
public static unsafe void DispatchContinuations<T, TOps>(T task) where T : Task, ITaskCompletionAction where TOps : IRuntimeAsyncTaskOps<T>
{
ExecutionAndSyncBlockStore contexts = default;
contexts.Push();
Expand Down Expand Up @@ -422,9 +458,20 @@ private static Continuation UnwindToPossibleHandler(Continuation continuation)
}
}

public static void HandleSuspended<T, TOps>(T task) where T : Task where TOps : IThunkTaskOps<T>
public static void HandleSuspended<T, TOps>(T task) where T : Task, ITaskCompletionAction where TOps : IRuntimeAsyncTaskOps<T>
{
Continuation headContinuation = UnlinkHeadContinuation(out INotifyCompletion? notifier);
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
ICriticalNotifyCompletion? critNotifier = state.CriticalNotifier;
INotifyCompletion? notifier = state.Notifier;
Task? calledTask = state.CalledTask;

state.CriticalNotifier = null;
state.Notifier = null;
state.CalledTask = null;

Continuation sentinelContinuation = state.SentinelContinuation!;
Continuation headContinuation = sentinelContinuation.Next!;
sentinelContinuation.Next = null;

// Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter.
// These never have special continuation handling.
Expand All @@ -438,9 +485,19 @@ public static void HandleSuspended<T, TOps>(T task) where T : Task where TOps :

try
{
if (notifier is ICriticalNotifyCompletion crit)
if (critNotifier != null)
{
critNotifier.UnsafeOnCompleted(TOps.GetContinuationAction(task));
}
else if (calledTask != null)
{
crit.UnsafeOnCompleted(TOps.GetContinuationAction(task));
// Runtime async callable wrapper for task returning
// method. This implements the context transparent
// forwarding and makes these wrappers minimal cost.
if (!calledTask.TryAddCompletionAction(task))
{
ThreadPool.UnsafeQueueUserWorkItemInternal(task, preferLocal: true);
}
}
else
{
Expand All @@ -454,19 +511,7 @@ public static void HandleSuspended<T, TOps>(T task) where T : Task where TOps :
}
}

private static Continuation UnlinkHeadContinuation(out INotifyCompletion? notifier)
{
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
notifier = state.Notifier;
state.Notifier = null;

Continuation sentinelContinuation = state.SentinelContinuation!;
Continuation head = sentinelContinuation.Next!;
sentinelContinuation.Next = null;
return head;
}

private static bool QueueContinuationFollowUpActionIfNecessary<T, TOps>(T task, Continuation continuation) where T : Task where TOps : IThunkTaskOps<T>
private static bool QueueContinuationFollowUpActionIfNecessary<T, TOps>(T task, Continuation continuation) where T : Task where TOps : IRuntimeAsyncTaskOps<T>
{
if ((continuation.Flags & CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL) != 0)
{
Expand Down Expand Up @@ -554,7 +599,7 @@ private static bool QueueContinuationFollowUpActionIfNecessary<T, TOps>(T task,

continuation.Next = finalContinuation;

ThunkTask<T?> result = new();
RuntimeAsyncTask<T?> result = new();
result.HandleSuspended();
return result;
}
Expand All @@ -567,7 +612,7 @@ private static Task FinalizeTaskReturningThunk(Continuation continuation)
};
continuation.Next = finalContinuation;

ThunkTask result = new();
RuntimeAsyncTask result = new();
result.HandleSuspended();
return result;
}
Expand Down Expand Up @@ -679,5 +724,16 @@ private static void CaptureContinuationContext(SynchronizationContext syncCtx, r

flags |= CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL;
}

internal static T CompletedTaskResult<T>(Task<T> task)
{
TaskAwaiter.ValidateEnd(task);
return task.ResultOnSuccess;
}

internal static void CompletedTask(Task task)
{
TaskAwaiter.ValidateEnd(task);
}
}
}
Loading
Loading