Skip to content

Add missing ValueTask support with InProcessNoEmit toolchain #2110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions src/BenchmarkDotNet/Code/DeclarationsProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ private string GetMethodName(MethodInfo method)
(method.ReturnType.GetGenericTypeDefinition() == typeof(Task<>) ||
method.ReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>))))
{
return $"() => {method.Name}().GetAwaiter().GetResult()";
return $"() => BenchmarkDotNet.Helpers.AwaitHelper.GetResult({method.Name}())";
}

return method.Name;
Expand Down Expand Up @@ -149,12 +149,10 @@ internal class TaskDeclarationsProvider : VoidDeclarationsProvider
{
public TaskDeclarationsProvider(Descriptor descriptor) : base(descriptor) { }

// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
// and will eventually throw actual exception, not aggregated one
public override string WorkloadMethodDelegate(string passArguments)
=> $"({passArguments}) => {{ {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
=> $"({passArguments}) => {{ BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";

public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
public override string GetWorkloadMethodCall(string passArguments) => $"BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";

protected override Type WorkloadMethodReturnType => typeof(void);
}
Expand All @@ -168,11 +166,9 @@ public GenericTaskDeclarationsProvider(Descriptor descriptor) : base(descriptor)

protected override Type WorkloadMethodReturnType => Descriptor.WorkloadMethod.ReturnType.GetTypeInfo().GetGenericArguments().Single();

// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
// and will eventually throw actual exception, not aggregated one
public override string WorkloadMethodDelegate(string passArguments)
=> $"({passArguments}) => {{ return {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
=> $"({passArguments}) => {{ return BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";

public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
public override string GetWorkloadMethodCall(string passArguments) => $"BenchmarkDotNet.Helpers.AwaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";
}
}
127 changes: 127 additions & 0 deletions src/BenchmarkDotNet/Helpers/AwaitHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
using System;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

namespace BenchmarkDotNet.Helpers
{
public static class AwaitHelper
{
private class ValueTaskWaiter
{
// We use thread static field so that each thread uses its own individual callback and reset event.
[ThreadStatic]
private static ValueTaskWaiter ts_current;
internal static ValueTaskWaiter Current => ts_current ??= new ValueTaskWaiter();

// We cache the callback to prevent allocations for memory diagnoser.
private readonly Action awaiterCallback;
private readonly ManualResetEventSlim resetEvent;

private ValueTaskWaiter()
{
resetEvent = new ();
awaiterCallback = resetEvent.Set;
}

// Hook up a callback instead of converting to Task to prevent extra allocations on each benchmark run.
internal void Wait(ConfiguredValueTaskAwaitable.ConfiguredValueTaskAwaiter awaiter)
{
resetEvent.Reset();
awaiter.UnsafeOnCompleted(awaiterCallback);

// The fastest way to wait for completion is to spin a bit before waiting on the event. This is the same logic that Task.GetAwaiter().GetResult() uses.
var spinner = new SpinWait();
while (!resetEvent.IsSet)
{
if (spinner.NextSpinWillYield)
{
resetEvent.Wait();
return;
}
spinner.SpinOnce();
}
}

internal void Wait<T>(ConfiguredValueTaskAwaitable<T>.ConfiguredValueTaskAwaiter awaiter)
{
resetEvent.Reset();
awaiter.UnsafeOnCompleted(awaiterCallback);

// The fastest way to wait for completion is to spin a bit before waiting on the event. This is the same logic that Task.GetAwaiter().GetResult() uses.
var spinner = new SpinWait();
while (!resetEvent.IsSet)
{
if (spinner.NextSpinWillYield)
{
resetEvent.Wait();
return;
}
spinner.SpinOnce();
}
}
}

// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
// and will eventually throw actual exception, not aggregated one
public static void GetResult(Task task) => task.GetAwaiter().GetResult();

public static T GetResult<T>(Task<T> task) => task.GetAwaiter().GetResult();

// ValueTask can be backed by an IValueTaskSource that only supports asynchronous awaits, so we have to hook up a callback instead of calling .GetAwaiter().GetResult() like we do for Task.
// The alternative is to convert it to Task using .AsTask(), but that causes allocations which we must avoid for memory diagnoser.
public static void GetResult(ValueTask task)
{
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
var awaiter = task.ConfigureAwait(false).GetAwaiter();
if (!awaiter.IsCompleted)
{
ValueTaskWaiter.Current.Wait(awaiter);
}
awaiter.GetResult();
}

public static T GetResult<T>(ValueTask<T> task)
{
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
var awaiter = task.ConfigureAwait(false).GetAwaiter();
if (!awaiter.IsCompleted)
{
ValueTaskWaiter.Current.Wait(awaiter);
}
return awaiter.GetResult();
}

internal static MethodInfo GetGetResultMethod(Type taskType)
{
if (!taskType.IsGenericType)
{
return typeof(AwaitHelper).GetMethod(nameof(AwaitHelper.GetResult), BindingFlags.Public | BindingFlags.Static, null, new Type[1] { taskType }, null);
}

Type compareType = taskType.GetGenericTypeDefinition() == typeof(ValueTask<>) ? typeof(ValueTask<>)
: typeof(Task).IsAssignableFrom(taskType.GetGenericTypeDefinition()) ? typeof(Task<>)
: null;
if (compareType == null)
{
return null;
}
var resultType = taskType
.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)
.ReturnType
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance)
.ReturnType;
return typeof(AwaitHelper).GetMethods(BindingFlags.Public | BindingFlags.Static)
.First(m =>
{
if (m.Name != nameof(AwaitHelper.GetResult)) return false;
Type paramType = m.GetParameters().First().ParameterType;
// We have to compare the types indirectly, == check doesn't work.
return paramType.Assembly == compareType.Assembly && paramType.Namespace == compareType.Namespace && paramType.Name == compareType.Name;
})
.MakeGenericMethod(new[] { resultType });
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using BenchmarkDotNet.Engines;
using JetBrains.Annotations;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
Expand All @@ -17,28 +19,24 @@ public ConsumableTypeInfo(Type methodReturnType)

OriginMethodReturnType = methodReturnType;

// Please note this code does not support await over extension methods.
var getAwaiterMethod = methodReturnType.GetMethod(nameof(Task<int>.GetAwaiter), BindingFlagsPublicInstance);
if (getAwaiterMethod == null)
// Only support (Value)Task for parity with other toolchains (and so we can use AwaitHelper).
IsAwaitable = methodReturnType == typeof(Task) || methodReturnType == typeof(ValueTask)
|| (methodReturnType.GetTypeInfo().IsGenericType
&& (methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(Task<>)
|| methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(ValueTask<>)));

if (!IsAwaitable)
{
WorkloadMethodReturnType = methodReturnType;
}
else
{
var getResultMethod = getAwaiterMethod
WorkloadMethodReturnType = methodReturnType
.GetMethod(nameof(Task.GetAwaiter), BindingFlagsPublicInstance)
.ReturnType
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance);

if (getResultMethod == null)
{
WorkloadMethodReturnType = methodReturnType;
}
else
{
WorkloadMethodReturnType = getResultMethod.ReturnType;
GetAwaiterMethod = getAwaiterMethod;
GetResultMethod = getResultMethod;
}
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance)
.ReturnType;
GetResultMethod = Helpers.AwaitHelper.GetGetResultMethod(methodReturnType);
}

if (WorkloadMethodReturnType == null)
Expand Down Expand Up @@ -75,14 +73,13 @@ public ConsumableTypeInfo(Type methodReturnType)
public Type WorkloadMethodReturnType { get; }
public Type OverheadMethodReturnType { get; }

public MethodInfo? GetAwaiterMethod { get; }
public MethodInfo? GetResultMethod { get; }

public bool IsVoid { get; }
public bool IsByRef { get; }
public bool IsConsumable { get; }
public FieldInfo? WorkloadConsumableField { get; }

public bool IsAwaitable => GetAwaiterMethod != null && GetResultMethod != null;
public bool IsAwaitable { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -580,42 +580,28 @@ private MethodInfo EmitWorkloadImplementation(string methodName)
workloadInvokeMethod.ReturnParameter,
args);
args = methodBuilder.GetEmitParameters(args);
var callResultType = consumableInfo.OriginMethodReturnType;
var awaiterType = consumableInfo.GetAwaiterMethod?.ReturnType
?? throw new InvalidOperationException($"Bug: {nameof(consumableInfo.GetAwaiterMethod)} is null");

var ilBuilder = methodBuilder.GetILGenerator();

/*
.locals init (
[0] valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<int32>
)
*/
var callResultLocal =
ilBuilder.DeclareOptionalLocalForInstanceCall(callResultType, consumableInfo.GetAwaiterMethod);
var awaiterLocal =
ilBuilder.DeclareOptionalLocalForInstanceCall(awaiterType, consumableInfo.GetResultMethod);

/*
// return TaskSample(arg0). ... ;
IL_0000: ldarg.0
IL_0001: ldarg.1
IL_0002: call instance class [mscorlib]System.Threading.Tasks.Task`1<int32> [BenchmarkDotNet]BenchmarkDotNet.Samples.SampleBenchmark::TaskSample(int64)
*/
IL_0026: ldarg.0
IL_0027: ldloc.0
IL_0028: ldloc.1
IL_0029: ldloc.2
IL_002a: ldloc.3
IL_002b: call instance class [System.Private.CoreLib]System.Threading.Tasks.Task`1<object> BenchmarkDotNet.Helpers.Runnable_0::WorkloadMethod(string, string, string, string)
*/
if (!Descriptor.WorkloadMethod.IsStatic)
ilBuilder.Emit(OpCodes.Ldarg_0);
ilBuilder.EmitLdargs(args);
ilBuilder.Emit(OpCodes.Call, Descriptor.WorkloadMethod);

/*
// ... .GetAwaiter().GetResult();
IL_0007: callvirt instance valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<!0> class [mscorlib]System.Threading.Tasks.Task`1<int32>::GetAwaiter()
IL_000c: stloc.0
IL_000d: ldloca.s 0
IL_000f: call instance !0 valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<int32>::GetResult()
*/
ilBuilder.EmitInstanceCallThisValueOnStack(callResultLocal, consumableInfo.GetAwaiterMethod);
ilBuilder.EmitInstanceCallThisValueOnStack(awaiterLocal, consumableInfo.GetResultMethod);
// BenchmarkDotNet.Helpers.AwaitHelper.GetResult(...);
IL_000e: call !!0 BenchmarkDotNet.Helpers.AwaitHelper::GetResult<int32>(valuetype [System.Runtime]System.Threading.Tasks.ValueTask`1<!!0>)
*/

ilBuilder.Emit(OpCodes.Call, consumableInfo.GetResultMethod);

/*
IL_0014: ret
Expand Down Expand Up @@ -830,19 +816,6 @@ .locals init (
var skipFirstArg = workloadMethod.IsStatic;
var argLocals = EmitDeclareArgLocals(ilBuilder, skipFirstArg);

LocalBuilder callResultLocal = null;
LocalBuilder awaiterLocal = null;
if (consumableInfo.IsAwaitable)
{
var callResultType = consumableInfo.OriginMethodReturnType;
var awaiterType = consumableInfo.GetAwaiterMethod?.ReturnType
?? throw new InvalidOperationException($"Bug: {nameof(consumableInfo.GetAwaiterMethod)} is null");
callResultLocal =
ilBuilder.DeclareOptionalLocalForInstanceCall(callResultType, consumableInfo.GetAwaiterMethod);
awaiterLocal =
ilBuilder.DeclareOptionalLocalForInstanceCall(awaiterType, consumableInfo.GetResultMethod);
}

consumeEmitter.DeclareDisassemblyDiagnoserLocals(ilBuilder);

var notElevenLabel = ilBuilder.DefineLabel();
Expand All @@ -867,29 +840,27 @@ .locals init (
EmitLoadArgFieldsToLocals(ilBuilder, argLocals, skipFirstArg);

/*
// return TaskSample(_argField) ... ;
IL_0011: ldarg.0
IL_0012: ldloc.0
IL_0013: call instance class [mscorlib]System.Threading.Tasks.Task`1<int32> [BenchmarkDotNet]BenchmarkDotNet.Samples.SampleBenchmark::TaskSample(int64)
IL_0018: ret
IL_0026: ldarg.0
IL_0027: ldloc.0
IL_0028: ldloc.1
IL_0029: ldloc.2
IL_002a: ldloc.3
IL_002b: call instance class [System.Private.CoreLib]System.Threading.Tasks.Task`1<object> BenchmarkDotNet.Helpers.Runnable_0::WorkloadMethod(string, string, string, string)
*/

if (!workloadMethod.IsStatic)
{
ilBuilder.Emit(OpCodes.Ldarg_0);
}
ilBuilder.EmitLdLocals(argLocals);
ilBuilder.Emit(OpCodes.Call, workloadMethod);

if (consumableInfo.IsAwaitable)
{
/*
// ... .GetAwaiter().GetResult();
IL_0007: callvirt instance valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<!0> class [mscorlib]System.Threading.Tasks.Task`1<int32>::GetAwaiter()
IL_000c: stloc.0
IL_000d: ldloca.s 0
IL_000f: call instance !0 valuetype [mscorlib]System.Runtime.CompilerServices.TaskAwaiter`1<int32>::GetResult()
*/
ilBuilder.EmitInstanceCallThisValueOnStack(callResultLocal, consumableInfo.GetAwaiterMethod);
ilBuilder.EmitInstanceCallThisValueOnStack(awaiterLocal, consumableInfo.GetResultMethod);
// BenchmarkDotNet.Helpers.AwaitHelper.GetResult(...);
IL_000e: call !!0 BenchmarkDotNet.Helpers.AwaitHelper::GetResult<int32>(valuetype [System.Runtime]System.Threading.Tasks.ValueTask`1<!!0>)
*/
ilBuilder.Emit(OpCodes.Call, consumableInfo.GetResultMethod);
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ private static BenchmarkAction CreateCore(
if (resultType == typeof(Task))
return new BenchmarkActionTask(resultInstance, targetMethod, unrollFactor);

if (resultType == typeof(ValueTask))
return new BenchmarkActionValueTask(resultInstance, targetMethod, unrollFactor);

if (resultType.GetTypeInfo().IsGenericType)
{
var genericType = resultType.GetGenericTypeDefinition();
Expand Down
Loading