Skip to content

Moving IModelCombiner to Ensemble and related changes #1563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 7, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions src/Microsoft.ML.Core/Prediction/ITrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML.Runtime.Data;

namespace Microsoft.ML.Runtime
Expand Down Expand Up @@ -94,12 +93,4 @@ public static IPredictor Train(this ITrainer trainer, RoleMappedData trainData)
public static TPredictor Train<TPredictor>(this ITrainer<TPredictor> trainer, RoleMappedData trainData) where TPredictor : IPredictor
=> trainer.Train(new TrainContext(trainData));
}

/// <summary>
/// An interface that combines multiple predictors into a single predictor.
/// </summary>
public interface IModelCombiner
{
IPredictor CombineModels(IEnumerable<IPredictor> models);
}
}
1 change: 1 addition & 0 deletions src/Microsoft.ML.Core/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.ML;

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TestFramework" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)]

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Legacy" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Maml" + PublicKey.Value)]
Expand Down
1 change: 0 additions & 1 deletion src/Microsoft.ML.Ensemble/Microsoft.ML.Ensemble.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
<ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
<ProjectReference Include="..\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
<ProjectReference Include="..\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />
</ItemGroup>

</Project>
13 changes: 0 additions & 13 deletions src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.Model;

[assembly: LoadableClass(typeof(MultiStacking), typeof(MultiStacking.Arguments), typeof(SignatureCombiner),
Expand Down Expand Up @@ -50,17 +48,6 @@ public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerF
internal override IComponentFactory<ITrainer<TVectorPredictor>> GetPredictorFactory() => BasePredictorType;

public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this);

public Arguments()
{
// REVIEW: Perhaps we can have a better non-parametric learner.
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
env => new Ova(env, new Ova.Arguments()
{
PredictorType = ComponentFactoryUtils.CreateFromFunction(
e => new FastTreeBinaryClassificationTrainer(e, DefaultColumnNames.Label, DefaultColumnNames.Features))
}));
}
}

public MultiStacking(IHostEnvironment env, Arguments args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
using System;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Model;

Expand Down Expand Up @@ -47,12 +45,6 @@ public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerF

internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType;

public Arguments()
{
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
env => new FastTreeRegressionTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features));
}

public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);
}

Expand Down
8 changes: 0 additions & 8 deletions src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
using System;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Model;

Expand Down Expand Up @@ -45,12 +43,6 @@ public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFacto

internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType;

public Arguments()
{
BasePredictorType = ComponentFactoryUtils.CreateFromFunction(
env => new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features));
}

public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);
}

Expand Down
17 changes: 17 additions & 0 deletions src/Microsoft.ML.Ensemble/Trainer/IModelCombiner.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML.Runtime;

namespace Microsoft.ML.Runtime.Ensemble
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) I'd prefer if we started matching folders and namespaces as much as possible. It makes finding files easier (just like if the file name and the class name match).

This file is in the Trainer folder, but not in a Trainer namespace.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes indeed. I am making it consistent with the files here (so in a limited, myopic sense my action here was correct), but that does not change the fact that the system, such as it is, is slapdash and haphazard to the point where it's mostly futile to try to find anything without just a broad search. 😛 Let us open an issue on this, I will try to do so before I have to get the kids ready for school.

Copy link
Contributor Author

@TomFinley TomFinley Nov 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this is not such a simple matter -- we first need to decide what those namespaces will be. Obviously it won't be Microsoft.ML.Runtime.Ensemble, but what? Microsoft.ML.Ensemble? Maybe there's already a proposal open for all I know. We have namespaces outlined for specific components I believe, but not a general principle.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems very reasonable to me that the assembly name and the root namespace match (that's the default in .csproj files). So Microsoft.ML.Ensemble sounds like a good proposal to me.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either this, or Microsoft.ML.Borborygmization, I accept nothing else


In reply to: 231567175 [](ancestors = 231567175)

{
/// <summary>
/// An interface that combines multiple predictors into a single predictor.
/// Implementations define how the individual models' outputs are merged
/// (e.g. averaging, voting, or stacking) into one <see cref="IPredictor"/>.
/// </summary>
public interface IModelCombiner
{
/// <summary>
/// Combines the given predictors into a single predictor.
/// </summary>
/// <param name="models">The predictors to combine. Implementations may require
/// all models to be of a compatible prediction kind — TODO confirm per implementation.</param>
/// <returns>A single predictor representing the combination of <paramref name="models"/>.</returns>
IPredictor CombineModels(IEnumerable<IPredictor> models);
}
}
16 changes: 8 additions & 8 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
{
protected readonly TArgs Args;
protected readonly bool AllowGC;
protected Ensemble TrainedEnsemble;
protected TreeEnsemble TrainedEnsemble;
protected int FeatureCount;
protected RoleMappedData ValidData;
protected IParallelTraining ParallelTraining;
Expand All @@ -76,7 +76,7 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
protected double[] InitValidScores;
protected double[][] InitTestScores;
//protected int Iteration;
protected Ensemble Ensemble;
protected TreeEnsemble Ensemble;

protected bool HasValidSet => ValidSet != null;

Expand Down Expand Up @@ -478,7 +478,7 @@ protected bool AreSamplesWeighted(IChannel ch)

private void InitializeEnsemble()
{
Ensemble = new Ensemble();
Ensemble = new TreeEnsemble();
}

/// <summary>
Expand Down Expand Up @@ -914,7 +914,7 @@ internal abstract class DataConverter
/// of features we actually trained on. This can be null in the event that no filtering
/// occurred.
/// </summary>
/// <seealso cref="Ensemble.RemapFeatures"/>
/// <seealso cref="TreeEnsemble.RemapFeatures"/>
public int[] FeatureMap;

protected readonly IHost Host;
Expand Down Expand Up @@ -2810,7 +2810,7 @@ public abstract class FastTreePredictionWrapper :
ISingleCanSaveOnnx
{
//The below two properties are necessary for tree Visualizer
public Ensemble TrainedEnsemble { get; }
public TreeEnsemble TrainedEnsemble { get; }
public int NumTrees => TrainedEnsemble.NumTrees;

// Inner args is used only for documentation purposes when saving comments to INI files.
Expand All @@ -2834,7 +2834,7 @@ public abstract class FastTreePredictionWrapper :
public bool CanSavePfa => true;
public bool CanSaveOnnx(OnnxContext ctx) => true;

protected FastTreePredictionWrapper(IHostEnvironment env, string name, Ensemble trainedEnsemble, int numFeatures, string innerArgs)
protected FastTreePredictionWrapper(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
: base(env, name)
{
Host.CheckValue(trainedEnsemble, nameof(trainedEnsemble));
Expand Down Expand Up @@ -2871,7 +2871,7 @@ protected FastTreePredictionWrapper(IHostEnvironment env, string name, ModelLoad
if (ctx.Header.ModelVerWritten >= VerCategoricalSplitSerialized)
categoricalSplits = true;

TrainedEnsemble = new Ensemble(ctx, usingDefaultValues, categoricalSplits);
TrainedEnsemble = new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits);
MaxSplitFeatIdx = FindMaxFeatureIndex(TrainedEnsemble);

InnerArgs = ctx.LoadStringOrNull();
Expand Down Expand Up @@ -3264,7 +3264,7 @@ public void GetFeatureWeights(ref VBuffer<Float> weights)
bldr.GetResult(ref weights);
}

private static int FindMaxFeatureIndex(Ensemble ensemble)
private static int FindMaxFeatureIndex(TreeEnsemble ensemble)
{
int ifeatMax = 0;
for (int i = 0; i < ensemble.NumTrees; i++)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ private static VersionInfo GetVersionInfo()

protected override uint VerCategoricalSplitSerialized => 0x00010005;

internal FastTreeBinaryPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
internal FastTreeBinaryPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
{
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeRanking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,7 @@ private static VersionInfo GetVersionInfo()

protected override uint VerCategoricalSplitSerialized => 0x00010005;

internal FastTreeRankingPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
internal FastTreeRankingPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
{
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ private static VersionInfo GetVersionInfo()

protected override uint VerCategoricalSplitSerialized => 0x00010005;

internal FastTreeRegressionPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
internal FastTreeRegressionPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
{
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeTweedie.cs
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ private static VersionInfo GetVersionInfo()

protected override uint VerCategoricalSplitSerialized => 0x00010003;

internal FastTreeTweediePredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
internal FastTreeTweediePredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
{
}
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
<ProjectReference Include="..\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj" />
</ItemGroup>

</Project>
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RandomForestClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ private static VersionInfo GetVersionInfo()
/// </summary>
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

public FastForestClassificationPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs)
public FastForestClassificationPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
{ }

Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RandomForestRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ private static VersionInfo GetVersionInfo()

protected override uint VerCategoricalSplitSerialized => 0x00010006;

public FastForestRegressionPredictor(IHostEnvironment env, Ensemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount)
public FastForestRegressionPredictor(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount)
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
{
_quantileSampleCount = samplesCount;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/Training/BaggingProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public int GetBagCount(int numTrees, int bagSize)
// Divides output values of leaves to bag count.
// This brings back the final scores generated by model on a same
// range as when we didn't use bagging
public void ScaleEnsembleLeaves(int numTrees, int bagSize, Ensemble ensemble)
public void ScaleEnsembleLeaves(int numTrees, int bagSize, TreeEnsemble ensemble)
{
int bagCount = GetBagCount(numTrees, bagSize);
for (int t = 0; t < ensemble.NumTrees; t++)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ public interface IEnsembleCompressor<TLabel>

void SetTreeScores(int idx, double[] scores);

bool Compress(IChannel ch, Ensemble ensemble, double[] trainScores, int bestIteration, int maxTreesAfterCompression);
bool Compress(IChannel ch, TreeEnsemble ensemble, double[] trainScores, int bestIteration, int maxTreesAfterCompression);

Ensemble GetCompressedEnsemble();
TreeEnsemble GetCompressedEnsemble();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public class LassoBasedEnsembleCompressor : IEnsembleCompressor<short>

private Dataset _trainSet;
private short[] _labels;
private Ensemble _compressedEnsemble;
private TreeEnsemble _compressedEnsemble;
private int[] _sampleObservationIndices;
private Random _rnd;

Expand Down Expand Up @@ -458,9 +458,9 @@ private LassoFit GetLassoFit(IChannel ch, int maxAllowedFeaturesPerModel)
return fit;
}

private Ensemble GetEnsembleFromSolution(LassoFit fit, int solutionIdx, Ensemble originalEnsemble)
private TreeEnsemble GetEnsembleFromSolution(LassoFit fit, int solutionIdx, TreeEnsemble originalEnsemble)
{
Ensemble ensemble = new Ensemble();
TreeEnsemble ensemble = new TreeEnsemble();

int weightsCount = fit.NumberOfWeights[solutionIdx];
for (int i = 0; i < weightsCount; i++)
Expand Down Expand Up @@ -534,7 +534,7 @@ private unsafe void LoadTargets(double[] trainScores, int bestIteration)
}
}

public bool Compress(IChannel ch, Ensemble ensemble, double[] trainScores, int bestIteration, int maxTreesAfterCompression)
public bool Compress(IChannel ch, TreeEnsemble ensemble, double[] trainScores, int bestIteration, int maxTreesAfterCompression)
{
LoadTargets(trainScores, bestIteration);

Expand All @@ -552,7 +552,7 @@ public bool Compress(IChannel ch, Ensemble ensemble, double[] trainScores, int b
return true;
}

public Ensemble GetCompressedEnsemble()
public TreeEnsemble GetCompressedEnsemble()
{
return _compressedEnsemble;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Microsoft.ML.Trainers.FastTree.Internal
//Accelerated gradient descent score tracker
public class AcceleratedGradientDescent : GradientDescent
{
public AcceleratedGradientDescent(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
public AcceleratedGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
: base(ensemble, trainData, initTrainScores, gradientWrapper)
{
UseFastTrainingScoresUpdate = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class ConjugateGradientDescent : GradientDescent
private double[] _currentGradient;
private double[] _currentDk;

public ConjugateGradientDescent(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
public ConjugateGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
: base(ensemble, trainData, initTrainScores, gradientWrapper)
{
_currentDk = new double[trainData.NumDocs];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class GradientDescent : OptimizationAlgorithm
private double[] _droppedScores;
private double[] _scores;

public GradientDescent(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
public GradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
: base(ensemble, trainData, initTrainScores)
{
_gradientWrapper = gradientWrapper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class RandomForestOptimizer : GradientDescent
{
private IGradientAdjuster _gradientWrapper;
// REVIEW: When the FastTree appliation is decoupled with tree learner and boosting logic, this class should be removed.
public RandomForestOptimizer(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
public RandomForestOptimizer(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
: base(ensemble, trainData, initTrainScores, gradientWrapper)
{
_gradientWrapper = gradientWrapper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public abstract class OptimizationAlgorithm
public delegate void PreScoreUpdateHandler(IChannel ch);
public PreScoreUpdateHandler PreScoreUpdateEvent;

public Ensemble Ensemble;
public TreeEnsemble Ensemble;

public ScoreTracker TrainingScores;
public List<ScoreTracker> TrackedScores;
Expand All @@ -37,7 +37,7 @@ public abstract class OptimizationAlgorithm
public Random DropoutRng;
public bool UseFastTrainingScoresUpdate;

public OptimizationAlgorithm(Ensemble ensemble, Dataset trainData, double[] initTrainScores)
public OptimizationAlgorithm(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores)
{
Ensemble = ensemble;
TrainingScores = ConstructScoreTracker("train", trainData, initTrainScores);
Expand Down
Loading