Skip to content

Commit e63fa8f

Browse files
Fixes #4385 about calling the Create methods when loading models from disk (#4485)
* Changed ComponentCatalog so it would use the most public "initter" * Changed the visibility of several constructors and create methods so to choose the correct initter when loading models from disk
1 parent 549b389 commit e63fa8f

File tree

41 files changed

+150
-64
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+150
-64
lines changed

src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,54 @@
1313

1414
namespace Microsoft.ML.Runtime
1515
{
16+
17+
internal static class Extension
18+
{
19+
internal static AccessModifier Accessmodifier(this MethodInfo methodInfo)
20+
{
21+
if (methodInfo.IsFamilyAndAssembly)
22+
return AccessModifier.PrivateProtected;
23+
if (methodInfo.IsPrivate)
24+
return AccessModifier.Private;
25+
if (methodInfo.IsFamily)
26+
return AccessModifier.Protected;
27+
if (methodInfo.IsFamilyOrAssembly)
28+
return AccessModifier.ProtectedInternal;
29+
if (methodInfo.IsAssembly)
30+
return AccessModifier.Internal;
31+
if (methodInfo.IsPublic)
32+
return AccessModifier.Public;
33+
throw new ArgumentException("Did not find access modifier", "methodInfo");
34+
}
35+
36+
internal static AccessModifier Accessmodifier(this ConstructorInfo constructorInfo)
37+
{
38+
if (constructorInfo.IsFamilyAndAssembly)
39+
return AccessModifier.PrivateProtected;
40+
if (constructorInfo.IsPrivate)
41+
return AccessModifier.Private;
42+
if (constructorInfo.IsFamily)
43+
return AccessModifier.Protected;
44+
if (constructorInfo.IsFamilyOrAssembly)
45+
return AccessModifier.ProtectedInternal;
46+
if (constructorInfo.IsAssembly)
47+
return AccessModifier.Internal;
48+
if (constructorInfo.IsPublic)
49+
return AccessModifier.Public;
50+
throw new ArgumentException("Did not find access modifier", "constructorInfo");
51+
}
52+
53+
internal enum AccessModifier
54+
{
55+
PrivateProtected,
56+
Private,
57+
Protected,
58+
ProtectedInternal,
59+
Internal,
60+
Public
61+
}
62+
}
63+
1664
/// <summary>
1765
/// This catalogs instantiatable components (aka, loadable classes). Components are registered via
1866
/// a descendant of <see cref="LoadableClassAttributeBase"/>, identifying the names and signature types under which the component
@@ -414,21 +462,59 @@ private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTyp
414462
ctor = null;
415463
create = null;
416464
requireEnvironment = false;
465+
bool requireEnvironmentCtor = false;
466+
bool requireEnvironmentCreate = false;
417467
var parmTypesWithEnv = Utils.Concat(new Type[1] { typeof(IHostEnvironment) }, parmTypes);
468+
418469
if (Utils.Size(parmTypes) == 0 && (getter = FindInstanceGetter(instType, loaderType)) != null)
419470
return true;
420-
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypes ?? Type.EmptyTypes, null)) != null)
421-
return true;
422-
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypesWithEnv ?? Type.EmptyTypes, null)) != null)
471+
472+
// Find both 'ctor' and 'create' methods if available
473+
if (instType.IsAssignableFrom(loaderType))
474+
{
475+
if ((ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypes ?? Type.EmptyTypes, null)) == null)
476+
{
477+
if ((ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypesWithEnv ?? Type.EmptyTypes, null)) != null)
478+
requireEnvironmentCtor = true;
479+
}
480+
}
481+
482+
if ((create = FindCreateMethod(instType, loaderType, parmTypes ?? Type.EmptyTypes)) == null)
423483
{
424-
requireEnvironment = true;
484+
if ((create = FindCreateMethod(instType, loaderType, parmTypesWithEnv ?? Type.EmptyTypes)) != null)
485+
requireEnvironmentCreate = true;
486+
}
487+
488+
if (ctor != null && create != null)
489+
{
490+
// If both 'ctor' and 'create' methods were found
491+
// Choose the one that is 'more' public
492+
// If they have the same visibility, then throw an exception, since this shouldn't happen.
493+
494+
if (ctor.Accessmodifier() == create.Accessmodifier())
495+
{
496+
throw Contracts.Except($"Can't load type {instType}, because it has both create and constructor methods with the same visibility. Please indicate which one should be used by changing either the signature or the visibility of one of them.");
497+
}
498+
if (ctor.Accessmodifier() > create.Accessmodifier())
499+
{
500+
create = null;
501+
requireEnvironment = requireEnvironmentCtor;
502+
return true;
503+
}
504+
ctor = null;
505+
requireEnvironment = requireEnvironmentCreate;
425506
return true;
426507
}
427-
if ((create = FindCreateMethod(instType, loaderType, parmTypes ?? Type.EmptyTypes)) != null)
508+
509+
if (ctor != null && create == null)
510+
{
511+
requireEnvironment = requireEnvironmentCtor;
428512
return true;
429-
if ((create = FindCreateMethod(instType, loaderType, parmTypesWithEnv ?? Type.EmptyTypes)) != null)
513+
}
514+
515+
if (ctor == null && create != null)
430516
{
431-
requireEnvironment = true;
517+
requireEnvironment = requireEnvironmentCreate;
432518
return true;
433519
}
434520

src/Microsoft.ML.Data/Prediction/Calibrator.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,7 @@ private SchemaBindableCalibratedModelParameters(IHostEnvironment env, ModelLoadC
760760
_featureContribution = SubModel as IFeatureContributionMapper;
761761
}
762762

763-
private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx)
763+
internal static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx)
764764
{
765765
Contracts.CheckValue(ctx, nameof(ctx));
766766
ctx.CheckAtModel(GetVersionInfo());
@@ -1224,7 +1224,7 @@ private NaiveCalibrator(IHostEnvironment env, ModelLoadContext ctx)
12241224
_host.CheckDecode(_binProbs.All(x => (0 <= x && x <= 1)));
12251225
}
12261226

1227-
private static NaiveCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
1227+
internal static NaiveCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
12281228
{
12291229
Contracts.CheckValue(env, nameof(env));
12301230
env.CheckValue(ctx, nameof(ctx));
@@ -1675,7 +1675,7 @@ private PlattCalibrator(IHostEnvironment env, ModelLoadContext ctx)
16751675
_host.CheckDecode(FloatUtils.IsFinite(Offset));
16761676
}
16771677

1678-
private static PlattCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
1678+
internal static PlattCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
16791679
{
16801680
Contracts.CheckValue(env, nameof(env));
16811681
env.CheckValue(ctx, nameof(ctx));
@@ -1972,7 +1972,7 @@ private IsotonicCalibrator(IHostEnvironment env, ModelLoadContext ctx)
19721972
_host.CheckDecode(valuePrev <= 1);
19731973
}
19741974

1975-
private static IsotonicCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
1975+
internal static IsotonicCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
19761976
{
19771977
Contracts.CheckValue(env, nameof(env));
19781978
env.CheckValue(ctx, nameof(ctx));

src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
148148
}
149149

150150
// Factory method for SignatureLoadModel.
151-
private static FeatureContributionCalculatingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
151+
internal static FeatureContributionCalculatingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
152152
{
153153
Contracts.CheckValue(env, nameof(env));
154154
ctx.CheckAtModel(GetVersionInfo());

src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ public LabelIndicatorTransform(IHostEnvironment env,
130130
{
131131
}
132132

133-
public LabelIndicatorTransform(IHostEnvironment env, Options options, IDataView input)
133+
internal LabelIndicatorTransform(IHostEnvironment env, Options options, IDataView input)
134134
: base(env, LoadName, Contracts.CheckRef(options, nameof(options)).Columns,
135135
input, TestIsMulticlassLabel)
136136
{

src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ private SkipTakeFilter(long skip, long take, IHostEnvironment env, IDataView inp
101101
/// <param name="env">Host Environment.</param>
102102
/// <param name="options">Options for the skip operation.</param>
103103
/// <param name="input">Input <see cref="IDataView"/>.</param>
104-
public SkipTakeFilter(IHostEnvironment env, SkipOptions options, IDataView input)
104+
internal SkipTakeFilter(IHostEnvironment env, SkipOptions options, IDataView input)
105105
: this(options.Count, Options.DefaultTake, env, input)
106106
{
107107
}
@@ -112,7 +112,7 @@ public SkipTakeFilter(IHostEnvironment env, SkipOptions options, IDataView input
112112
/// <param name="env">Host Environment.</param>
113113
/// <param name="options">Options for the take operation.</param>
114114
/// <param name="input">Input <see cref="IDataView"/>.</param>
115-
public SkipTakeFilter(IHostEnvironment env, TakeOptions options, IDataView input)
115+
internal SkipTakeFilter(IHostEnvironment env, TakeOptions options, IDataView input)
116116
: this(Options.DefaultSkip, options.Count, env, input)
117117
{
118118
}

src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ private SlotsDroppingTransformer(IHostEnvironment env, ModelLoadContext ctx)
302302
}
303303

304304
// Factory method for SignatureLoadModel.
305-
private static SlotsDroppingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
305+
internal static SlotsDroppingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
306306
{
307307
Contracts.CheckValue(env, nameof(env));
308308
ctx.CheckAtModel(GetVersionInfo());

src/Microsoft.ML.Ensemble/Trainer/EnsembleDistributionModelParameters.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ private bool IsValid(IValueMapperDist mapper, out VectorDataViewType inputType)
119119
}
120120
}
121121

122-
private static EnsembleDistributionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
122+
internal static EnsembleDistributionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
123123
{
124124
Contracts.CheckValue(env, nameof(env));
125125
env.CheckValue(ctx, nameof(ctx));

src/Microsoft.ML.Ensemble/Trainer/EnsembleModelParameters.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ private bool IsValid(IValueMapper mapper, out VectorDataViewType inputType)
109109
}
110110
}
111111

112-
private static EnsembleModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
112+
internal static EnsembleModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
113113
{
114114
Contracts.CheckValue(env, nameof(env));
115115
env.CheckValue(ctx, nameof(ctx));

src/Microsoft.ML.Ensemble/Trainer/Multiclass/EnsembleMulticlassModelParameters.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ private void InitializeMappers(out IValueMapper[] mappers, out VectorDataViewTyp
9191
inputType = new VectorDataViewType(NumberDataViewType.Single);
9292
}
9393

94-
private static EnsembleMulticlassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
94+
internal static EnsembleMulticlassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
9595
{
9696
Contracts.CheckValue(env, nameof(env));
9797
env.CheckValue(ctx, nameof(ctx));

src/Microsoft.ML.FastTree/FastTreeClassification.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
8484
ctx.SetVersionInfo(GetVersionInfo());
8585
}
8686

87-
private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
87+
internal static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
8888
{
8989
Contracts.CheckValue(env, nameof(env));
9090
env.CheckValue(ctx, nameof(ctx));

0 commit comments

Comments
 (0)