Skip to content

Commit

Permalink
Fixes dotnet#4385 about calling the Create methods when loading model…
Browse files Browse the repository at this point in the history
…s from disk (dotnet#4485)

* Changed ComponentCatalog so it would use the most public "initter"
* Changed the visibility of several constructors and create methods so to choose the correct initter when loading models from disk
  • Loading branch information
antoniovs1029 authored Nov 27, 2019
1 parent 549b389 commit e63fa8f
Show file tree
Hide file tree
Showing 41 changed files with 150 additions and 64 deletions.
100 changes: 93 additions & 7 deletions src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,54 @@

namespace Microsoft.ML.Runtime
{

internal static class Extension
{
internal static AccessModifier Accessmodifier(this MethodInfo methodInfo)
{
if (methodInfo.IsFamilyAndAssembly)
return AccessModifier.PrivateProtected;
if (methodInfo.IsPrivate)
return AccessModifier.Private;
if (methodInfo.IsFamily)
return AccessModifier.Protected;
if (methodInfo.IsFamilyOrAssembly)
return AccessModifier.ProtectedInternal;
if (methodInfo.IsAssembly)
return AccessModifier.Internal;
if (methodInfo.IsPublic)
return AccessModifier.Public;
throw new ArgumentException("Did not find access modifier", "methodInfo");
}

internal static AccessModifier Accessmodifier(this ConstructorInfo constructorInfo)
{
if (constructorInfo.IsFamilyAndAssembly)
return AccessModifier.PrivateProtected;
if (constructorInfo.IsPrivate)
return AccessModifier.Private;
if (constructorInfo.IsFamily)
return AccessModifier.Protected;
if (constructorInfo.IsFamilyOrAssembly)
return AccessModifier.ProtectedInternal;
if (constructorInfo.IsAssembly)
return AccessModifier.Internal;
if (constructorInfo.IsPublic)
return AccessModifier.Public;
throw new ArgumentException("Did not find access modifier", "constructorInfo");
}

internal enum AccessModifier
{
PrivateProtected,
Private,
Protected,
ProtectedInternal,
Internal,
Public
}
}

/// <summary>
/// This catalogs instantiatable components (aka, loadable classes). Components are registered via
/// a descendant of <see cref="LoadableClassAttributeBase"/>, identifying the names and signature types under which the component
Expand Down Expand Up @@ -414,21 +462,59 @@ private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTyp
ctor = null;
create = null;
requireEnvironment = false;
bool requireEnvironmentCtor = false;
bool requireEnvironmentCreate = false;
var parmTypesWithEnv = Utils.Concat(new Type[1] { typeof(IHostEnvironment) }, parmTypes);

if (Utils.Size(parmTypes) == 0 && (getter = FindInstanceGetter(instType, loaderType)) != null)
return true;
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypes ?? Type.EmptyTypes, null)) != null)
return true;
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypesWithEnv ?? Type.EmptyTypes, null)) != null)

// Find both 'ctor' and 'create' methods if available
if (instType.IsAssignableFrom(loaderType))
{
if ((ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypes ?? Type.EmptyTypes, null)) == null)
{
if ((ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypesWithEnv ?? Type.EmptyTypes, null)) != null)
requireEnvironmentCtor = true;
}
}

if ((create = FindCreateMethod(instType, loaderType, parmTypes ?? Type.EmptyTypes)) == null)
{
requireEnvironment = true;
if ((create = FindCreateMethod(instType, loaderType, parmTypesWithEnv ?? Type.EmptyTypes)) != null)
requireEnvironmentCreate = true;
}

if (ctor != null && create != null)
{
// If both 'ctor' and 'create' methods were found
// Choose the one that is 'more' public
// If they have the same visibility, then throw an exception, since this shouldn't happen.

if (ctor.Accessmodifier() == create.Accessmodifier())
{
throw Contracts.Except($"Can't load type {instType}, because it has both create and constructor methods with the same visibility. Please indicate which one should be used by changing either the signature or the visibility of one of them.");
}
if (ctor.Accessmodifier() > create.Accessmodifier())
{
create = null;
requireEnvironment = requireEnvironmentCtor;
return true;
}
ctor = null;
requireEnvironment = requireEnvironmentCreate;
return true;
}
if ((create = FindCreateMethod(instType, loaderType, parmTypes ?? Type.EmptyTypes)) != null)

if (ctor != null && create == null)
{
requireEnvironment = requireEnvironmentCtor;
return true;
if ((create = FindCreateMethod(instType, loaderType, parmTypesWithEnv ?? Type.EmptyTypes)) != null)
}

if (ctor == null && create != null)
{
requireEnvironment = true;
requireEnvironment = requireEnvironmentCreate;
return true;
}

Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ private SchemaBindableCalibratedModelParameters(IHostEnvironment env, ModelLoadC
_featureContribution = SubModel as IFeatureContributionMapper;
}

private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx)
internal static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
Expand Down Expand Up @@ -1224,7 +1224,7 @@ private NaiveCalibrator(IHostEnvironment env, ModelLoadContext ctx)
_host.CheckDecode(_binProbs.All(x => (0 <= x && x <= 1)));
}

private static NaiveCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
internal static NaiveCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down Expand Up @@ -1675,7 +1675,7 @@ private PlattCalibrator(IHostEnvironment env, ModelLoadContext ctx)
_host.CheckDecode(FloatUtils.IsFinite(Offset));
}

private static PlattCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
internal static PlattCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down Expand Up @@ -1972,7 +1972,7 @@ private IsotonicCalibrator(IHostEnvironment env, ModelLoadContext ctx)
_host.CheckDecode(valuePrev <= 1);
}

private static IsotonicCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
internal static IsotonicCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
}

// Factory method for SignatureLoadModel.
private static FeatureContributionCalculatingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
internal static FeatureContributionCalculatingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
ctx.CheckAtModel(GetVersionInfo());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public LabelIndicatorTransform(IHostEnvironment env,
{
}

public LabelIndicatorTransform(IHostEnvironment env, Options options, IDataView input)
internal LabelIndicatorTransform(IHostEnvironment env, Options options, IDataView input)
: base(env, LoadName, Contracts.CheckRef(options, nameof(options)).Columns,
input, TestIsMulticlassLabel)
{
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ private SkipTakeFilter(long skip, long take, IHostEnvironment env, IDataView inp
/// <param name="env">Host Environment.</param>
/// <param name="options">Options for the skip operation.</param>
/// <param name="input">Input <see cref="IDataView"/>.</param>
public SkipTakeFilter(IHostEnvironment env, SkipOptions options, IDataView input)
internal SkipTakeFilter(IHostEnvironment env, SkipOptions options, IDataView input)
: this(options.Count, Options.DefaultTake, env, input)
{
}
Expand All @@ -112,7 +112,7 @@ public SkipTakeFilter(IHostEnvironment env, SkipOptions options, IDataView input
/// <param name="env">Host Environment.</param>
/// <param name="options">Options for the take operation.</param>
/// <param name="input">Input <see cref="IDataView"/>.</param>
public SkipTakeFilter(IHostEnvironment env, TakeOptions options, IDataView input)
internal SkipTakeFilter(IHostEnvironment env, TakeOptions options, IDataView input)
: this(Options.DefaultSkip, options.Count, env, input)
{
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ private SlotsDroppingTransformer(IHostEnvironment env, ModelLoadContext ctx)
}

// Factory method for SignatureLoadModel.
private static SlotsDroppingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
internal static SlotsDroppingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
ctx.CheckAtModel(GetVersionInfo());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ private bool IsValid(IValueMapperDist mapper, out VectorDataViewType inputType)
}
}

private static EnsembleDistributionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static EnsembleDistributionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ private bool IsValid(IValueMapper mapper, out VectorDataViewType inputType)
}
}

private static EnsembleModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static EnsembleModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ private void InitializeMappers(out IValueMapper[] mappers, out VectorDataViewTyp
inputType = new VectorDataViewType(NumberDataViewType.Single);
}

private static EnsembleMulticlassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static EnsembleMulticlassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
internal static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeRanking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static FastTreeRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static FastTreeRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
return new FastTreeRankingModelParameters(env, ctx);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static FastTreeRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static FastTreeRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeTweedie.cs
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static FastTreeTweedieModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static FastTreeTweedieModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/GamClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(GamBinaryModelParameters).Assembly.FullName);
}

private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
internal static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/GamRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(GamRegressionModelParameters).Assembly.FullName);
}

private static GamRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static GamRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RandomForestClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
internal static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RandomForestRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.Writer.Write(_quantileSampleCount);
}

private static FastForestRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static FastForestRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(TreeEnsembleFeaturizationTransformer).Assembly.FullName);
}

private static TreeEnsembleFeaturizationTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
internal static TreeEnsembleFeaturizationTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
=> new TreeEnsembleFeaturizationTransformer(env, ctx);
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
/// <summary>
/// This method is called by reflection to instantiate a predictor.
/// </summary>
private static KMeansModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static KMeansModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
internal static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static LightGbmRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static LightGbmRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
return new LightGbmRankingModelParameters(env, ctx);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static LightGbmRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static LightGbmRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ private static void ProbCheckDecode(Double p)
Contracts.CheckDecode(0 <= p && p <= 1);
}

private static OlsModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static OlsModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.PCA/PcaTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
writer.WriteSinglesNoCount(_eigenVectors[i].GetValues().Slice(0, _dimension));
}

private static PcaModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static PcaModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ private MatrixFactorizationModelParameters(IHostEnvironment env, ModelLoadContex
/// <summary>
/// Load model from the given context
/// </summary>
private static MatrixFactorizationModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static MatrixFactorizationModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down Expand Up @@ -556,7 +556,7 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(MatrixFactorizationPredictionTransformer).Assembly.FullName);
}
private static MatrixFactorizationPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
internal static MatrixFactorizationPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
=> new MatrixFactorizationPredictionTransformer(env, ctx);

}
Expand Down
Loading

0 comments on commit e63fa8f

Please sign in to comment.