diff --git a/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`2.cs b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`2.cs new file mode 100644 index 00000000000..162130dd915 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncInstanceMethodInfo1`2.cs @@ -0,0 +1,89 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// The method is an instance method on an object of type . + /// One generic type argument. + /// A return value of . + /// + /// + /// The type of the receiver of the instance method. + /// The type of the return value of the method. + internal sealed class FuncInstanceMethodInfo1 : FuncMethodInfo1 + where TTarget : class + { + private static readonly string _targetTypeCheckMessage = $"Should have a target type of '{typeof(TTarget)}'"; + + public FuncInstanceMethodInfo1(Func function) + : this(function.Method) + { + } + + private FuncInstanceMethodInfo1(MethodInfo methodInfo) + : base(methodInfo) + { + Contracts.CheckParam(!GenericMethodDefinition.IsStatic, nameof(methodInfo), "Should be an instance method"); + Contracts.CheckParam(GenericMethodDefinition.DeclaringType == typeof(TTarget), nameof(methodInfo), _targetTypeCheckMessage); + } + + /// + /// Creates a representing the for + /// a generic instance method. This helper method allows the instance to be created prior to the creation of any + /// instances of the target type. The following example shows the creation of an instance representing the + /// method: + /// + /// + /// FuncInstanceMethodInfo1<object, int>.Create(obj => obj.GetHashCode) + /// + /// + /// The expression which creates the delegate for an instance of the target type. + /// A representing the + /// for the generic instance method. + public static FuncInstanceMethodInfo1 Create(Expression>> expression) + { + if (!(expression is { Body: UnaryExpression { Operand: MethodCallExpression methodCallExpression } })) + { + throw Contracts.ExceptParam(nameof(expression), "Unexpected expression form"); + } + + // Verify that we are calling MethodInfo.CreateDelegate(Type, object) + Contracts.CheckParam(methodCallExpression.Method.DeclaringType == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.Name == nameof(MethodInfo.CreateDelegate), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters().Length == 2, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters()[0].ParameterType == typeof(Type), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Method.GetParameters()[1].ParameterType == typeof(object), nameof(expression), "Unexpected expression form"); + + // Verify that we are creating a delegate of type Func + Contracts.CheckParam(methodCallExpression.Arguments.Count == 2, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[0] is ConstantExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((ConstantExpression)methodCallExpression.Arguments[0]).Type == typeof(Type), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam((Type)((ConstantExpression)methodCallExpression.Arguments[0]).Value == typeof(Func), nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[1] is ParameterExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(methodCallExpression.Arguments[1] == expression.Parameters[0], nameof(expression), "Unexpected expression form"); + + // Check the MethodInfo + Contracts.CheckParam(methodCallExpression.Object is ConstantExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((ConstantExpression)methodCallExpression.Object).Type == typeof(MethodInfo), nameof(expression), "Unexpected expression form"); + + var methodInfo = (MethodInfo)((ConstantExpression)methodCallExpression.Object).Value; + Contracts.CheckParam(expression.Body is UnaryExpression, nameof(expression), "Unexpected expression form"); + Contracts.CheckParam(((UnaryExpression)expression.Body).Operand is MethodCallExpression, nameof(expression), "Unexpected expression form"); + + return new FuncInstanceMethodInfo1(methodInfo); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncMethodInfo1`1.cs b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo1`1.cs new file mode 100644 index 00000000000..9f81085d799 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo1`1.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Collections.Immutable; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// One generic type argument. + /// A return value of . + /// + /// + /// The type of the return value of the method. + internal abstract class FuncMethodInfo1 : FuncMethodInfo + { + private ImmutableDictionary _instanceMethodInfo; + + private protected FuncMethodInfo1(MethodInfo methodInfo) + : base(methodInfo) + { + _instanceMethodInfo = ImmutableDictionary.Empty; + + Contracts.CheckParam(GenericMethodDefinition.GetGenericArguments().Length == 1, nameof(methodInfo), + "Should have exactly one generic type parameter but does not"); + } + + public MethodInfo MakeGenericMethod(Type typeArg1) + { + return ImmutableInterlocked.GetOrAdd( + ref _instanceMethodInfo, + typeArg1, + (typeArg, methodInfo) => methodInfo.MakeGenericMethod(typeArg), + GenericMethodDefinition); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncMethodInfo`1.cs b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo`1.cs new file mode 100644 index 00000000000..c51c70630b7 --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncMethodInfo`1.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + internal abstract class FuncMethodInfo + { + private protected FuncMethodInfo(MethodInfo methodInfo) + { + Contracts.CheckValue(methodInfo, nameof(methodInfo)); + Contracts.CheckParam(methodInfo.IsGenericMethod, nameof(methodInfo), "Should be generic but is not"); + + GenericMethodDefinition = methodInfo.GetGenericMethodDefinition(); + Contracts.CheckParam(GenericMethodDefinition.ReturnType == typeof(TResult), nameof(methodInfo), "Cannot be generic on return type"); + } + + protected MethodInfo GenericMethodDefinition { get; } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo1`1.cs b/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo1`1.cs new file mode 100644 index 00000000000..6f63431941f --- /dev/null +++ b/src/Microsoft.ML.Core/Utilities/FuncStaticMethodInfo1`1.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +using System; +using System.Reflection; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Internal.Utilities +{ + /// + /// Represents the for a generic function corresponding to , + /// with the following characteristics: + /// + /// + /// The method is static. + /// One generic type argument. + /// A return value of . + /// + /// + /// The type of the return value of the method. + internal sealed class FuncStaticMethodInfo1 : FuncMethodInfo1 + { + public FuncStaticMethodInfo1(Func function) + : base(function.Method) + { + Contracts.CheckParam(GenericMethodDefinition.IsStatic, nameof(function), "Should be a static method"); + } + } +} diff --git a/src/Microsoft.ML.Core/Utilities/Utils.cs b/src/Microsoft.ML.Core/Utilities/Utils.cs index 506fd27962a..36ee1ffd890 100644 --- a/src/Microsoft.ML.Core/Utilities/Utils.cs +++ b/src/Microsoft.ML.Core/Utilities/Utils.cs @@ -988,20 +988,32 @@ private static MethodInfo MarshalInvokeCheckAndCreate(Type[] genArgs, Dele /// Because it is strongly typed, this can only be applied to methods whose return type /// is known at compile time, that is, that do not depend on the type parameter of the method itself. /// - /// The return value + /// The type of the receiver of the instance method. + /// The type of the return value of the method. /// A delegate that should be a generic method with a single type parameter. /// The generic method definition will be extracted, then a new method will be created with the /// given type parameter, then the method will be invoked. + /// The target of the invocation. /// The new type parameter for the generic method /// The return value of the invoked function - public static TRet MarshalInvoke(Func func, Type genArg) + public static TResult MarshalInvoke(FuncInstanceMethodInfo1 func, TTarget target, Type genArg) + where TTarget : class { - var meth = MarshalInvokeCheckAndCreate(genArg, func); - return (TRet)meth.Invoke(func.Target, null); + var meth = func.MakeGenericMethod(genArg); + return (TResult)meth.Invoke(target, null); + } + + /// + /// A static version of . + /// + public static TResult MarshalInvoke(FuncStaticMethodInfo1 func, Type genArg) + { + var meth = func.MakeGenericMethod(genArg); + return (TResult)meth.Invoke(null, null); } /// - /// A one-argument version of . + /// A one-argument version of . /// public static TRet MarshalInvoke(Func func, Type genArg, TArg1 arg1) { @@ -1010,7 +1022,7 @@ public static TRet MarshalInvoke(Func func, Type genAr } /// - /// A two-argument version of . + /// A two-argument version of . /// public static TRet MarshalInvoke(Func func, Type genArg, TArg1 arg1, TArg2 arg2) { @@ -1019,7 +1031,7 @@ public static TRet MarshalInvoke(Func fu } /// - /// A three-argument version of . + /// A three-argument version of . /// public static TRet MarshalInvoke(Func func, Type genArg, TArg1 arg1, TArg2 arg2, TArg3 arg3) @@ -1029,7 +1041,7 @@ public static TRet MarshalInvoke(Func - /// A four-argument version of . + /// A four-argument version of . /// public static TRet MarshalInvoke(Func func, Type genArg, TArg1 arg1, TArg2 arg2, TArg3 arg3, TArg4 arg4) @@ -1039,7 +1051,7 @@ public static TRet MarshalInvoke(Func - /// A five-argument version of . + /// A five-argument version of . /// public static TRet MarshalInvoke(Func func, Type genArg, TArg1 arg1, TArg2 arg2, TArg3 arg3, TArg4 arg4, TArg5 arg5) @@ -1049,7 +1061,7 @@ public static TRet MarshalInvoke(Func - /// A six-argument version of . + /// A six-argument version of . /// public static TRet MarshalInvoke(Func func, Type genArg, TArg1 arg1, TArg2 arg2, TArg3 arg3, TArg4 arg4, TArg5 arg5, TArg6 arg6) @@ -1059,7 +1071,7 @@ public static TRet MarshalInvoke } /// - /// A seven-argument version of . + /// A seven-argument version of . /// public static TRet MarshalInvoke(Func func, Type genArg, TArg1 arg1, TArg2 arg2, TArg3 arg3, TArg4 arg4, TArg5 arg5, TArg6 arg6, TArg7 arg7) @@ -1069,7 +1081,7 @@ public static TRet MarshalInvoke - /// An eight-argument version of . + /// An eight-argument version of . /// public static TRet MarshalInvoke(Func func, Type genArg, TArg1 arg1, TArg2 arg2, TArg3 arg3, TArg4 arg4, TArg5 arg5, TArg6 arg6, TArg7 arg7, TArg8 arg8) @@ -1079,7 +1091,7 @@ public static TRet MarshalInvoke - /// A nine-argument version of . + /// A nine-argument version of . /// public static TRet MarshalInvoke( Func func, @@ -1090,7 +1102,7 @@ public static TRet MarshalInvoke - /// A ten-argument version of . + /// A ten-argument version of . /// public static TRet MarshalInvoke( Func func, @@ -1101,7 +1113,7 @@ public static TRet MarshalInvoke - /// A 1 argument and n type version of . + /// A 1 argument and n type version of . /// public static TRet MarshalInvoke( Func func, @@ -1112,7 +1124,7 @@ public static TRet MarshalInvoke( } /// - /// A 2 argument and n type version of . + /// A 2 argument and n type version of . /// public static TRet MarshalInvoke( Func func, @@ -1147,7 +1159,7 @@ private static MethodInfo MarshalActionInvokeCheckAndCreate(Type[] typeArguments } /// - /// This is akin to , except applied to + /// This is akin to , except applied to /// instead of . /// /// A delegate that should be a generic method with a single type parameter. diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs index 6e9c4fd9866..4d58d82cfdf 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs @@ -1240,6 +1240,9 @@ public DataViewRowCursor[] GetRowCursorSet(IEnumerable co private sealed class Cursor : RootCursorBase { + private static readonly FuncInstanceMethodInfo1 _noRowGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.NoRowGetter); + private readonly BinaryLoader _parent; private readonly int[] _colToActivesIndex; private readonly TableOfContentsEntry[] _actives; @@ -2071,7 +2074,7 @@ public override ValueGetter GetGetter(DataViewSchema.Column colu /// a delegate that simply always throws. /// private Delegate GetNoRowGetter(DataViewType type) - => Utils.MarshalInvoke(NoRowGetter, type.RawType); + => Utils.MarshalInvoke(_noRowGetterMethodInfo, this, type.RawType); private Delegate NoRowGetter() { diff --git a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs index eeb885efe2f..5427bcdca5b 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/FakeSchema.cs @@ -15,6 +15,9 @@ namespace Microsoft.ML.Data.DataLoadSave [BestFriend] internal static class FakeSchemaFactory { + private static readonly FuncStaticMethodInfo1 _getDefaultVectorGetterMethodInfo = new FuncStaticMethodInfo1(GetDefaultVectorGetter); + private static readonly FuncStaticMethodInfo1 _getDefaultGetterMethodInfo = new FuncStaticMethodInfo1(GetDefaultGetter); + private const int AllVectorSizes = 10; private const int AllKeySizes = 10; @@ -31,9 +34,9 @@ public static DataViewSchema Create(SchemaShape shape) var metaColumnType = MakeColumnType(partialAnnotations[j]); Delegate del; if (metaColumnType is VectorDataViewType vectorType) - del = Utils.MarshalInvoke(GetDefaultVectorGetter, vectorType.ItemType.RawType); + del = Utils.MarshalInvoke(_getDefaultVectorGetterMethodInfo, vectorType.ItemType.RawType); else - del = Utils.MarshalInvoke(GetDefaultGetter, metaColumnType.RawType); + del = Utils.MarshalInvoke(_getDefaultGetterMethodInfo, metaColumnType.RawType); metaBuilder.Add(partialAnnotations[j].Name, metaColumnType, del); } builder.AddColumn(shape[i].Name, MakeColumnType(shape[i]), metaBuilder.ToAnnotations()); diff --git a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs index 6df350c20c7..868f0611737 100644 --- a/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs @@ -832,6 +832,15 @@ private protected AnnotationInfo(string kind, DataViewType annotationType) /// Type of the annotation value. internal sealed class AnnotationInfo : AnnotationInfo { + private static readonly FuncInstanceMethodInfo1, Delegate> _getArrayGetterMethodInfo + = FuncInstanceMethodInfo1, Delegate>.Create(target => target.GetArrayGetter); + + private static readonly FuncInstanceMethodInfo1, Delegate> _getGetterCoreMethodInfo + = FuncInstanceMethodInfo1, Delegate>.Create(target => target.GetGetterCore); + + private static readonly FuncInstanceMethodInfo1, Delegate> _getVBufferGetterMethodInfo + = FuncInstanceMethodInfo1, Delegate>.Create(target => target.GetVBufferGetter); + public readonly T Value; /// @@ -900,10 +909,7 @@ public override ValueGetter GetGetter() // T[] -> VBuffer Contracts.Check(itemType == dstItemType); - Func>> srcMethod = GetArrayGetter; - - return srcMethod.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(dstItemType) - .Invoke(this, new object[] { }) as ValueGetter; + return Utils.MarshalInvoke(_getArrayGetterMethodInfo, this, dstItemType) as ValueGetter; } if (AnnotationType is VectorDataViewType annotationVectorType) { @@ -919,10 +925,7 @@ public override ValueGetter GetGetter() Contracts.Assert(itemType == annotationVectorType.ItemType.RawType); Contracts.Check(itemType == dstItemType); - Func>> srcMethod = GetVBufferGetter; - return srcMethod.GetMethodInfo().GetGenericMethodDefinition() - .MakeGenericMethod(annotationVectorType.ItemType.RawType) - .Invoke(this, new object[] { }) as ValueGetter; + return Utils.MarshalInvoke(_getVBufferGetterMethodInfo, this, annotationVectorType.ItemType.RawType) as ValueGetter; } if (AnnotationType is PrimitiveDataViewType) { @@ -948,7 +951,7 @@ private Delegate GetGetterCore() internal override Delegate GetGetterDelegate() { - return Utils.MarshalInvoke(GetGetterCore, AnnotationType.RawType); + return Utils.MarshalInvoke(_getGetterCoreMethodInfo, this, AnnotationType.RawType); } private void GetStringArray(ref VBuffer> dst) diff --git a/src/Microsoft.ML.Data/DataView/Transposer.cs b/src/Microsoft.ML.Data/DataView/Transposer.cs index c40bd9727e3..fb70b731f84 100644 --- a/src/Microsoft.ML.Data/DataView/Transposer.cs +++ b/src/Microsoft.ML.Data/DataView/Transposer.cs @@ -1318,6 +1318,9 @@ public override ValueGetter GetGetter(DataViewSchema.Column colu internal static class TransposerUtils { + private static readonly FuncInstanceMethodInfo1 _slotCursorGetGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.GetGetter); + /// /// This is a convenience method that extracts a single slot value's vector, /// while simultaneously verifying that there is exactly one value. @@ -1359,9 +1362,7 @@ public static ValueGetter GetGetterWithVectorType(this SlotCurso var genTypeArgs = type.GetGenericArguments(); ctx.Assert(genTypeArgs.Length == 1); - Func>> del = cursor.GetGetter; - var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genTypeArgs[0]); - var getter = methodInfo.Invoke(cursor, null) as ValueGetter; + var getter = Utils.MarshalInvoke(_slotCursorGetGetterMethodInfo, cursor, genTypeArgs[0]) as ValueGetter; if (getter == null) throw ctx.Except("Invalid TValue: '{0}'", typeof(TValue)); return getter; diff --git a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs index 7f295abe7b4..147377f7143 100644 --- a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs +++ b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs @@ -293,9 +293,9 @@ public void AddInputVariable(VariableBinding binding, Type type) _ectx.AssertValue(type); if (binding is ArrayIndexVariableBinding) - type = Utils.MarshalInvoke(MakeArray, type); + type = type.MakeArrayType(); else if (binding is DictionaryKeyVariableBinding) - type = Utils.MarshalInvoke(MakeDictionary, type); + type = typeof(Dictionary<,>).MakeGenericType(typeof(string), type); EntryPointVariable v; if (!_vars.TryGetValue(binding.VariableName, out v)) @@ -308,16 +308,6 @@ public void AddInputVariable(VariableBinding binding, Type type) v.MarkUsage(true); } - private Type MakeArray() - { - return typeof(T[]); - } - - private Type MakeDictionary() - { - return typeof(Dictionary); - } - public void RemoveVariable(EntryPointVariable variable) { _ectx.CheckValue(variable, nameof(variable)); diff --git a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs index be78dab3dc5..3de842254eb 100644 --- a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs @@ -445,6 +445,12 @@ private protected override IRowMapper MakeRowMapper(DataViewSchema schema) private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { + private static readonly FuncInstanceMethodInfo1 _makeOneTrivialGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeOneTrivialGetter); + + private static readonly FuncInstanceMethodInfo1 _makeVecTrivialGetterMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeVecTrivialGetter); + private readonly SlotsDroppingTransformer _parent; private readonly int[] _cols; private readonly DataViewType[] _srcTypes; @@ -732,9 +738,7 @@ private Delegate MakeOneTrivialGetter(DataViewRow input, int iinfo) Host.Assert(!(_srcTypes[iinfo] is VectorDataViewType)); Host.Assert(_suppressed[iinfo]); - Func> del = MakeOneTrivialGetter; - var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(_srcTypes[iinfo].RawType); - return (Delegate)methodInfo.Invoke(this, new object[0]); + return Utils.MarshalInvoke(_makeOneTrivialGetterMethodInfo, this, _srcTypes[iinfo].RawType); } private ValueGetter MakeOneTrivialGetter() @@ -755,9 +759,7 @@ private Delegate MakeVecTrivialGetter(DataViewRow input, int iinfo) VectorDataViewType vectorType = (VectorDataViewType)_srcTypes[iinfo]; Host.Assert(_suppressed[iinfo]); - Func>> del = MakeVecTrivialGetter; - var methodInfo = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(vectorType.ItemType.RawType); - return (Delegate)methodInfo.Invoke(this, new object[0]); + return Utils.MarshalInvoke(_makeVecTrivialGetterMethodInfo, this, vectorType.ItemType.RawType); } private ValueGetter> MakeVecTrivialGetter() diff --git a/src/Microsoft.ML.Transforms/Expression/CodeGen.cs b/src/Microsoft.ML.Transforms/Expression/CodeGen.cs index f5ebef54887..6ebd4d59fbb 100644 --- a/src/Microsoft.ML.Transforms/Expression/CodeGen.cs +++ b/src/Microsoft.ML.Transforms/Expression/CodeGen.cs @@ -99,11 +99,12 @@ internal sealed partial class LambdaCompiler { private sealed class Visitor : ExprVisitor { + private static readonly MethodInfo _methGetFalseBL = ((Func)BuiltinFunctions.False).GetMethodInfo(); + private static readonly MethodInfo _methGetTrueBL = ((Func)BuiltinFunctions.True).GetMethodInfo(); + private MethodGenerator _meth; private ILGenerator _gen; private List _errors; - private readonly MethodInfo _methGetFalseBL; - private readonly MethodInfo _methGetTrueBL; private sealed class CachedWithLocal { @@ -141,14 +142,6 @@ public Visitor(MethodGenerator meth) _meth = meth; _gen = meth.Il; - Func f = BuiltinFunctions.False; - Contracts.Assert(f.Target == null); - _methGetFalseBL = f.GetMethodInfo(); - - Func t = BuiltinFunctions.True; - Contracts.Assert(t.Target == null); - _methGetTrueBL = t.GetMethodInfo(); - _cacheWith = new List(); } diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index 8a9f8877517..0e0d1a819b5 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -371,14 +371,11 @@ private BitArray ComputeDefaultSlots(DataViewType type, T[] values) private object GetDefault(DataViewType type) { - Func func = GetDefault; - var meth = func.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(type.GetItemType().RawType); - return meth.Invoke(this, null); - } + var rawType = type.GetItemType().RawType; + if (rawType.IsValueType) + return Activator.CreateInstance(rawType); - private object GetDefault() - { - return default(T); + return null; } /// diff --git a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs index 5de451ba7d8..69c041293c1 100644 --- a/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs +++ b/src/Microsoft.ML.Transforms/OptionalColumnTransform.cs @@ -245,6 +245,9 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(OptionalColumnTransform).Assembly.FullName); } + private static readonly FuncInstanceMethodInfo1 _makeGetterOneMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeGetterOne); + private readonly Bindings _bindings; private const string RegistrationName = "OptionalColumn"; @@ -404,7 +407,7 @@ private Delegate MakeGetter(int iinfo) var columnType = _bindings.ColumnTypes[iinfo]; if (columnType is VectorDataViewType vectorType) return Utils.MarshalInvoke(MakeGetterVec, vectorType.ItemType.RawType, vectorType.Size); - return Utils.MarshalInvoke(MakeGetterOne, columnType.RawType); + return Utils.MarshalInvoke(_makeGetterOneMethodInfo, this, columnType.RawType); } private Delegate MakeGetterOne() @@ -420,6 +423,9 @@ private Delegate MakeGetterVec(int length) private sealed class Cursor : SynchronizedCursorBase { + private static readonly FuncInstanceMethodInfo1 _makeGetterOneMethodInfo + = FuncInstanceMethodInfo1.Create(target => target.MakeGetterOne); + private readonly Bindings _bindings; private readonly bool[] _active; private readonly Delegate[] _getters; @@ -484,7 +490,7 @@ private Delegate MakeGetter(int iinfo) var columnType = _bindings.ColumnTypes[iinfo]; if (columnType is VectorDataViewType vectorType) return Utils.MarshalInvoke(MakeGetterVec, vectorType.ItemType.RawType, vectorType.Size); - return Utils.MarshalInvoke(MakeGetterOne, columnType.RawType); + return Utils.MarshalInvoke(_makeGetterOneMethodInfo, this, columnType.RawType); } private Delegate MakeGetterOne() diff --git a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs index 6395282a640..25eae9e6234 100644 --- a/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs +++ b/test/Microsoft.ML.Predictor.Tests/CmdLine/CmdLineReverseTest.cs @@ -82,7 +82,8 @@ public void NewTest() { var ml = new MLContext(1); ml.AddStandardComponents(); - var classes = Utils.MarshalInvoke(ml.ComponentCatalog.FindLoadableClasses, typeof(SignatureCalibrator)); + var findLoadableClassesMethodInfo = new FuncInstanceMethodInfo1(ml.ComponentCatalog.FindLoadableClasses); + var classes = Utils.MarshalInvoke(findLoadableClassesMethodInfo, ml.ComponentCatalog, typeof(SignatureCalibrator)); foreach (var cls in classes) { var factory = CmdParser.CreateComponentFactory(typeof(IComponentFactory), typeof(SignatureCalibrator), cls.LoadNames[0]);