Skip to content

Commit

Permalink
Enabling DI framework to scan the constructors with non-public visibi…
Browse files Browse the repository at this point in the history
…lity

* enabling scanning the constructors with non-public visibility, and reducing the visibility of some of them to avoid confusing the users.
  • Loading branch information
sfilipi authored and TomFinley committed Sep 20, 2018
1 parent 044a6d3 commit 6812cb5
Show file tree
Hide file tree
Showing 17 changed files with 18 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -454,9 +454,9 @@ private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTyp
var parmTypesWithEnv = Utils.Concat(new Type[1] { typeof(IHostEnvironment) }, parmTypes);
if (Utils.Size(parmTypes) == 0 && (getter = FindInstanceGetter(instType, loaderType)) != null)
return true;
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(parmTypes ?? Type.EmptyTypes)) != null)
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypes ?? Type.EmptyTypes, null)) != null)
return true;
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(parmTypesWithEnv ?? Type.EmptyTypes)) != null)
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypesWithEnv ?? Type.EmptyTypes, null)) != null)
{
requireEnvironment = true;
return true;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelCol
/// <summary>
/// Initializes a new instance of <see cref="FastTreeBinaryClassificationTrainer"/> by using the legacy <see cref="Arguments"/> class.
/// </summary>
public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args)
internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Arguments args)
: base(env, args, MakeLabelColumn(args.LabelColumn))
{
_outputColumns = new[]
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeRanking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public FastTreeRankingTrainer(IHostEnvironment env, string labelColumn, string f
/// <summary>
/// Initializes a new instance of <see cref="FastTreeRankingTrainer"/> by using the legacy <see cref="Arguments"/> class.
/// </summary>
public FastTreeRankingTrainer(IHostEnvironment env, Arguments args)
internal FastTreeRankingTrainer(IHostEnvironment env, Arguments args)
: base(env, args, MakeLabelColumn(args.LabelColumn))
{
_outputColumns = new[]
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, strin
/// <summary>
/// Initializes a new instance of <see cref="FastTreeRegressionTrainer"/> by using the legacy <see cref="Arguments"/> class.
/// </summary>
public FastTreeRegressionTrainer(IHostEnvironment env, Arguments args)
internal FastTreeRegressionTrainer(IHostEnvironment env, Arguments args)
: base(env, args, MakeLabelColumn(args.LabelColumn))
{
_outputColumns = new[]
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeTweedie.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public FastTreeTweedieTrainer(IHostEnvironment env, string labelColumn, string f
/// <summary>
/// Initializes a new instance of <see cref="FastTreeTweedieTrainer"/> by using the legacy <see cref="Arguments"/> class.
/// </summary>
public FastTreeTweedieTrainer(IHostEnvironment env, Arguments args)
internal FastTreeTweedieTrainer(IHostEnvironment env, Arguments args)
: base(env, args, MakeLabelColumn(args.LabelColumn))
{
Initialize();
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/GamClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public sealed class Arguments : ArgumentsBase
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
private protected override bool NeedCalibration => true;

public BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args)
internal BinaryClassificationGamTrainer(IHostEnvironment env, Arguments args)
: base(env, args)
{
_sigmoidParameter = 1;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/GamRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public partial class Arguments : ArgumentsBase

public override PredictionKind PredictionKind => PredictionKind.Regression;

public RegressionGamTrainer(IHostEnvironment env, Arguments args)
internal RegressionGamTrainer(IHostEnvironment env, Arguments args)
: base(env, args) { }

internal override void CheckLabel(RoleMappedData data)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RandomForestClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ public FastForestClassification(IHostEnvironment env, string labelColumn, string
/// <summary>
/// Initializes a new instance of <see cref="FastForestClassification"/> by using the legacy <see cref="Arguments"/> class.
/// </summary>
public FastForestClassification(IHostEnvironment env, Arguments args)
internal FastForestClassification(IHostEnvironment env, Arguments args)
: base(env, args, MakeLabelColumn(args.LabelColumn))
{
_outputColumns = new[]
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RandomForestRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public FastForestRegression(IHostEnvironment env, string labelColumn, string fea
/// <summary>
/// Initializes a new instance of <see cref="FastForestRegression"/> by using the legacy <see cref="Arguments"/> class.
/// </summary>
public FastForestRegression(IHostEnvironment env, Arguments args)
internal FastForestRegression(IHostEnvironment env, Arguments args)
: base(env, args, MakeLabelColumn(args.LabelColumn), true)
{
_outputColumns = new[]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public sealed class Arguments : ArgumentsBase
/// Developers should instantiate <see cref="Pkpd"/> by supplying the trainer argument directly to the <see cref="Pkpd"/> constructor
/// using the other public constructor.
/// </summary>
public Pkpd(IHostEnvironment env, Arguments args)
internal Pkpd(IHostEnvironment env, Arguments args)
: base(env, args, LoadNameValue)
{
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
};
}

public SdcaMultiClassTrainer(IHostEnvironment env, Arguments args)
internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args)
: this(env, args, args.FeatureColumn, args.LabelColumn)
{
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featur
};
}

public SdcaRegressionTrainer(IHostEnvironment env, Arguments args)
internal SdcaRegressionTrainer(IHostEnvironment env, Arguments args)
: this(env, args, args.FeatureColumn, args.LabelColumn)
{
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public void TrainSentiment()
}, text);

// Train
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { MaxIterations = 20 });
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { MaxIterations = 20 }, "Features", "Label");
var trainRoles = new RoleMappedData(trans, label: "Label", feature: "Features");

var predicted = trainer.Train(trainRoles);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void DecomposableTrainAndPredict()
var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename)));
var term = TermTransform.Create(env, loader, "Label");
var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term);
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 });
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label");

IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void New_DecomposableTrainAndPredict()
var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
var term = TermTransform.Create(env, loader, "Label");
var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term);
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 });
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label");

IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");
Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.Tests/Scenarios/Api/Extensibility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void Extensibility()
var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
.Transform(term);

var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 });
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label");

IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void TrainAndPredictIrisModelUsingDirectInstantiationTest()
pipeline = NormalizeTransform.CreateMinMaxNormalizer(env, pipeline, "Features");

// Train
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { NumThreads = 1 } );
var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments() { NumThreads = 1 }, "Features", "Label");

// Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto
var cached = new CacheDataView(env, pipeline, prefetch: null);
Expand Down

0 comments on commit 6812cb5

Please sign in to comment.