1111using System . Runtime . Versioning ;
1212using System . Threading ;
1313using System . Threading . Tasks ;
14+ using System . Threading . Tasks . Sources ;
1415
1516namespace System . Runtime . CompilerServices
1617{
@@ -150,9 +151,14 @@ public static partial class AsyncHelpers
150151 private struct RuntimeAsyncAwaitState
151152 {
152153 public Continuation ? SentinelContinuation ;
154+
155+ // The following are the possible introducers of asynchrony into a chain of awaits.
156+ // In other words - when we build a chain of continuations it would be logicaly attached
157+ // to one of these notifiers.
153158 public ICriticalNotifyCompletion ? CriticalNotifier ;
154159 public INotifyCompletion ? Notifier ;
155- public Task ? CalledTask ;
160+ public IValueTaskSourceNotifier ? ValueTaskSourceNotifier ;
161+ public Task ? TaskNotifier ;
156162 }
157163
158164 [ ThreadStatic ]
@@ -187,17 +193,35 @@ private static unsafe Continuation AllocContinuationClass(Continuation prevConti
187193 return newContinuation ;
188194 }
189195
196+ /// <summary>
197+ /// Used by internal thunks that implement awaiting on Task or a ValueTask.
198+ /// A ValueTask may wrap:
199+ /// - Completed result (we never await this)
200+ /// - Task
201+ /// - ValueTaskSource
202+ /// Therefore, when we are awaiting a ValueTask completion we are really
203+ /// awaiting a completion of an underlying Task or ValueTaskSource.
204+ /// </summary>
205+ /// <param name="o"> Task or a ValueTaskNotifier whose completion we are awaiting.</param>
190206 [ BypassReadyToRun ]
191207 [ MethodImpl ( MethodImplOptions . NoInlining | MethodImplOptions . Async ) ]
192208 [ RequiresPreviewFeatures ]
193- private static void TransparentAwaitTask ( Task t )
209+ private static void TransparentAwait ( object o )
194210 {
195211 ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState ;
196212 Continuation ? sentinelContinuation = state . SentinelContinuation ;
197213 if ( sentinelContinuation == null )
198214 state . SentinelContinuation = sentinelContinuation = new Continuation ( ) ;
199215
200- state . CalledTask = t ;
216+ if ( o is Task t )
217+ {
218+ state . TaskNotifier = t ;
219+ }
220+ else
221+ {
222+ state . ValueTaskSourceNotifier = ( IValueTaskSourceNotifier ) o ;
223+ }
224+
201225 AsyncSuspend ( sentinelContinuation ) ;
202226 }
203227
@@ -208,6 +232,7 @@ private interface IRuntimeAsyncTaskOps<T>
208232 static abstract void SetContinuationState ( T task , Continuation value ) ;
209233 static abstract bool SetCompleted ( T task ) ;
210234 static abstract void PostToSyncContext ( T task , SynchronizationContext syncCtx ) ;
235+ static abstract void ValueTaskSourceOnCompleted ( T task , IValueTaskSourceNotifier vtsNotifier , ValueTaskSourceOnCompletedFlags configFlags ) ;
211236 static abstract ref byte GetResultStorage ( T task ) ;
212237 }
213238
@@ -253,6 +278,12 @@ void ITaskCompletionAction.Invoke(Task completingTask)
253278 ( ( RuntimeAsyncTask < T > ) state ) . MoveNext ( ) ;
254279 } ;
255280
281+ public static readonly Action < object ? > s_runContinuationAction = static state =>
282+ {
283+ Debug . Assert ( state is RuntimeAsyncTask < T > ) ;
284+ ( ( RuntimeAsyncTask < T > ) state ) . MoveNext ( ) ;
285+ } ;
286+
256287 private struct Ops : IRuntimeAsyncTaskOps < RuntimeAsyncTask < T > >
257288 {
258289 public static Action GetContinuationAction ( RuntimeAsyncTask < T > task ) => ( Action ) task . m_action ! ;
@@ -272,6 +303,11 @@ public static void PostToSyncContext(RuntimeAsyncTask<T> task, SynchronizationCo
272303 syncContext . Post ( s_postCallback , task ) ;
273304 }
274305
306+ public static void ValueTaskSourceOnCompleted ( RuntimeAsyncTask < T > task , IValueTaskSourceNotifier vtsNotifier , ValueTaskSourceOnCompletedFlags configFlags )
307+ {
308+ vtsNotifier . OnCompleted ( s_runContinuationAction , task , configFlags ) ;
309+ }
310+
275311 public static ref byte GetResultStorage ( RuntimeAsyncTask < T > task ) => ref Unsafe . As < T ? , byte > ( ref task . m_result ) ;
276312 }
277313 }
@@ -318,6 +354,12 @@ void ITaskCompletionAction.Invoke(Task completingTask)
318354 ( ( RuntimeAsyncTask ) state ) . MoveNext ( ) ;
319355 } ;
320356
357+ public static readonly Action < object ? > s_runContinuationAction = static state =>
358+ {
359+ Debug . Assert ( state is RuntimeAsyncTask ) ;
360+ ( ( RuntimeAsyncTask ) state ) . MoveNext ( ) ;
361+ } ;
362+
321363 private struct Ops : IRuntimeAsyncTaskOps < RuntimeAsyncTask >
322364 {
323365 public static Action GetContinuationAction ( RuntimeAsyncTask task ) => ( Action ) task . m_action ! ;
@@ -337,6 +379,11 @@ public static void PostToSyncContext(RuntimeAsyncTask task, SynchronizationConte
337379 syncContext . Post ( s_postCallback , task ) ;
338380 }
339381
382+ public static void ValueTaskSourceOnCompleted ( RuntimeAsyncTask task , IValueTaskSourceNotifier vtsNotifier , ValueTaskSourceOnCompletedFlags configFlags )
383+ {
384+ vtsNotifier . OnCompleted ( s_runContinuationAction , task , configFlags ) ;
385+ }
386+
340387 public static ref byte GetResultStorage ( RuntimeAsyncTask task ) => ref Unsafe . NullRef < byte > ( ) ;
341388 }
342389 }
@@ -461,13 +508,16 @@ public static unsafe void DispatchContinuations<T, TOps>(T task) where T : Task,
461508 public static void HandleSuspended < T , TOps > ( T task ) where T : Task , ITaskCompletionAction where TOps : IRuntimeAsyncTaskOps < T >
462509 {
463510 ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState ;
511+
464512 ICriticalNotifyCompletion ? critNotifier = state . CriticalNotifier ;
465513 INotifyCompletion ? notifier = state . Notifier ;
466- Task ? calledTask = state . CalledTask ;
514+ IValueTaskSourceNotifier ? vtsNotifier = state . ValueTaskSourceNotifier ;
515+ Task ? taskNotifier = state . TaskNotifier ;
467516
468517 state . CriticalNotifier = null ;
469518 state . Notifier = null ;
470- state . CalledTask = null ;
519+ state . ValueTaskSourceNotifier = null ;
520+ state . TaskNotifier = null ;
471521
472522 Continuation sentinelContinuation = state . SentinelContinuation ! ;
473523 Continuation headContinuation = sentinelContinuation . Next ! ;
@@ -489,16 +539,43 @@ public static void HandleSuspended<T, TOps>(T task) where T : Task, ITaskComplet
489539 {
490540 critNotifier . UnsafeOnCompleted ( TOps . GetContinuationAction ( task ) ) ;
491541 }
492- else if ( calledTask != null )
542+ else if ( taskNotifier != null )
493543 {
494544 // Runtime async callable wrapper for task returning
495545 // method. This implements the context transparent
496546 // forwarding and makes these wrappers minimal cost.
497- if ( ! calledTask . TryAddCompletionAction ( task ) )
547+ if ( ! taskNotifier . TryAddCompletionAction ( task ) )
498548 {
499549 ThreadPool . UnsafeQueueUserWorkItemInternal ( task , preferLocal : true ) ;
500550 }
501551 }
552+ else if ( vtsNotifier != null )
553+ {
554+ // The awaiter must inform the ValueTaskSource source on whether the continuation
555+ // wants to run on a context, although the source may decide to ignore the suggestion.
556+ // Since the behavior of the source takes precedence, we clear the context flags of
557+ // the awaiting continuation (so it will run transparently on what the source decides)
558+ // and then tell the source if the awaiting frame prefers to continue on a context.
559+ // The reason why we do it here and not when the notifier is created is because
560+ // the continuation chain builds from the innermost frame out and at the time when the
561+ // notifier is created we do not know yet if the caller wants to continue on a context.
562+ ValueTaskSourceOnCompletedFlags configFlags = ValueTaskSourceOnCompletedFlags . None ;
563+ ContinuationFlags continuationFlags = headContinuation . Next ! . Flags ;
564+
565+ const ContinuationFlags continueOnContextFlags =
566+ ContinuationFlags . ContinueOnCapturedSynchronizationContext |
567+ ContinuationFlags . ContinueOnCapturedTaskScheduler ;
568+
569+ if ( ( continuationFlags & continueOnContextFlags ) != 0 )
570+ {
571+ // if await has captured some context, inform the source
572+ configFlags |= ValueTaskSourceOnCompletedFlags . UseSchedulingContext ;
573+ }
574+
575+ // Clear continuation flags, so that continuation runs transparently
576+ headContinuation . Next ! . Flags &= ~ continueFlags ;
577+ TOps . ValueTaskSourceOnCompleted ( task , vtsNotifier , configFlags ) ;
578+ }
502579 else
503580 {
504581 Debug . Assert ( notifier != null ) ;
0 commit comments