Skip to content

Change ensembles trainer to work with ITrainerEstimators instead of ITrainers #3796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 64 additions & 43 deletions src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ public abstract class ArgumentsBase
[TGUI(Label = "Validation Dataset Proportion")]
public Single ValidationDatasetProportion = 0.3f;

internal abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> GetPredictorFactory();
internal abstract IComponentFactory<ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>> GetPredictorFactory();
}

private protected readonly IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorType;
private protected readonly IComponentFactory<ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>> BasePredictorType;
private protected readonly IHost Host;
private protected IPredictorProducing<TOutput> Meta;

Expand Down Expand Up @@ -140,54 +140,75 @@ public void Train(List<FeatureSubsetModel<TOutput>> models, RoleMappedData data,
maps[i] = m.GetMapper<VBuffer<Single>, TOutput>();
}

// REVIEW: Should implement this better....
var labels = new Single[100];
var features = new VBuffer<Single>[100];
int count = 0;
// REVIEW: Should this include bad values or filter them?
using (var cursor = new FloatLabelCursor(data, CursOpt.AllFeatures | CursOpt.AllLabels))
var view = CreateDataView(host, ch, data, maps, models);
var trainer = BasePredictorType.CreateComponent(host);
if (trainer.Info.NeedNormalization)
ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
Meta = trainer.Fit(view).Model;
CheckMeta();
}
}

/// <summary>
/// Dispatches to the generic <c>CreateDataView&lt;T&gt;</c> overload based on the raw kind of the
/// label column, supplying a converter from the cursor's float label to the strongly typed label.
/// </summary>
/// <exception cref="Exception">Thrown via the channel when the label type is not BL, R4, or U4.</exception>
private IDataView CreateDataView(IHostEnvironment env, IChannel ch, RoleMappedData data, ValueMapper<VBuffer<Single>,
    TOutput>[] maps, List<FeatureSubsetModel<TOutput>> models)
{
    var labelKind = data.Schema.Label.Value.Type.GetRawKind();
    if (labelKind == InternalDataKind.BL)
    {
        // Boolean labels arrive as floats from the cursor; positive means true.
        return CreateDataView<bool>(env, ch, data, maps, models, x => x > 0);
    }
    if (labelKind == InternalDataKind.R4)
    {
        // Float labels pass through unchanged.
        return CreateDataView<float>(env, ch, data, maps, models, x => x);
    }
    if (labelKind == InternalDataKind.U4)
    {
        // U4 labels must be key-typed; NaN maps to the missing-key value 0, otherwise shift to 1-based.
        ch.Check(data.Schema.Label.Value.Type is KeyDataViewType);
        return CreateDataView(env, ch, data, maps, models, x => float.IsNaN(x) ? 0 : (uint)(x + 1));
    }
    throw ch.Except("Unsupported label type");
}

private IDataView CreateDataView<T>(IHostEnvironment env, IChannel ch, RoleMappedData data, ValueMapper<VBuffer<Single>, TOutput>[] maps,
List<FeatureSubsetModel<TOutput>> models, Func<float, T> labelConvert)
{
// REVIEW: Should implement this better....
var labels = new T[100];
var features = new VBuffer<Single>[100];
int count = 0;
// REVIEW: Should this include bad values or filter them?
using (var cursor = new FloatLabelCursor(data, CursOpt.AllFeatures | CursOpt.AllLabels))
{
TOutput[] predictions = new TOutput[maps.Length];
var vBuffers = new VBuffer<Single>[maps.Length];
while (cursor.MoveNext())
{
TOutput[] predictions = new TOutput[maps.Length];
var vBuffers = new VBuffer<Single>[maps.Length];
while (cursor.MoveNext())
Parallel.For(0, maps.Length, i =>
{
Parallel.For(0, maps.Length, i =>
var model = models[i];
if (model.SelectedFeatures != null)
{
var model = models[i];
if (model.SelectedFeatures != null)
{
EnsembleUtils.SelectFeatures(in cursor.Features, model.SelectedFeatures, model.Cardinality, ref vBuffers[i]);
maps[i](in vBuffers[i], ref predictions[i]);
}
else
maps[i](in cursor.Features, ref predictions[i]);
});

Utils.EnsureSize(ref labels, count + 1);
Utils.EnsureSize(ref features, count + 1);
labels[count] = cursor.Label;
FillFeatureBuffer(predictions, ref features[count]);
count++;
}
EnsembleUtils.SelectFeatures(in cursor.Features, model.SelectedFeatures, model.Cardinality, ref vBuffers[i]);
maps[i](in vBuffers[i], ref predictions[i]);
}
else
maps[i](in cursor.Features, ref predictions[i]);
});

Utils.EnsureSize(ref labels, count + 1);
Utils.EnsureSize(ref features, count + 1);
labels[count] = labelConvert(cursor.Label);
FillFeatureBuffer(predictions, ref features[count]);
count++;
}
}

ch.Info("The number of instances used for stacking trainer is {0}", count);

var bldr = new ArrayDataViewBuilder(host);
Array.Resize(ref labels, count);
Array.Resize(ref features, count);
bldr.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, labels);
bldr.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single, features);
ch.Info("The number of instances used for stacking trainer is {0}", count);

var view = bldr.GetDataView();
var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features);
var bldr = new ArrayDataViewBuilder(env);
Array.Resize(ref labels, count);
Array.Resize(ref features, count);
bldr.AddColumn(DefaultColumnNames.Label, data.Schema.Label.Value.Type as PrimitiveDataViewType, labels);
bldr.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single, features);

var trainer = BasePredictorType.CreateComponent(host);
if (trainer.Info.NeedNormalization)
ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
Meta = trainer.Train(rmd);
CheckMeta();
}
return bldr.GetDataView();
}
}
}
7 changes: 4 additions & 3 deletions src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

namespace Microsoft.ML.Trainers.Ensemble
{
using TVectorPredictor = IPredictorProducing<VBuffer<Single>>;
using TVectorTrainer = ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<VBuffer<float>>>, IPredictorProducing<VBuffer<float>>>;

internal sealed class MultiStacking : BaseStacking<VBuffer<Single>>, IMulticlassOutputCombiner
{
public const string LoadName = "MultiStacking";
Expand All @@ -44,9 +45,9 @@ public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerF
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMulticlassClassifierTrainer))]
[TGUI(Label = "Base predictor")]
public IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorType;
public IComponentFactory<TVectorTrainer> BasePredictorType;

internal override IComponentFactory<ITrainer<TVectorPredictor>> GetPredictorFactory() => BasePredictorType;
internal override IComponentFactory<TVectorTrainer> GetPredictorFactory() => BasePredictorType;

public IMulticlassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

namespace Microsoft.ML.Trainers.Ensemble
{
using TScalarPredictor = IPredictorProducing<Single>;
using TScalarTrainer = ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;

internal sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner
{
Expand All @@ -43,9 +43,9 @@ public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerF
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
[TGUI(Label = "Base predictor")]
public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType;
public IComponentFactory<TScalarTrainer> BasePredictorType;

internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType;
internal override IComponentFactory<TScalarTrainer> GetPredictorFactory() => BasePredictorType;

public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);
}
Expand Down
7 changes: 4 additions & 3 deletions src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

namespace Microsoft.ML.Trainers.Ensemble
{
using TScalarPredictor = IPredictorProducing<Single>;
using TScalarTrainer = ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;

internal sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner
{
public const string UserName = "Stacking";
Expand All @@ -41,9 +42,9 @@ public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFacto
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
[TGUI(Label = "Base predictor")]
public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType;
public IComponentFactory<TScalarTrainer> BasePredictorType;

internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType;
internal override IComponentFactory<TScalarTrainer> GetPredictorFactory() => BasePredictorType;

public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);
}
Expand Down
17 changes: 7 additions & 10 deletions src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ namespace Microsoft.ML.Trainers.Ensemble
{
using TDistPredictor = IDistPredictorProducing<Single, Single>;
using TScalarPredictor = IPredictorProducing<Single>;
using TScalarTrainer = ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;

/// <summary>
/// A generic ensemble trainer for binary classification.
/// </summary>
internal sealed class EnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
internal sealed class EnsembleTrainer : EnsembleTrainerBase<Single,
IBinarySubModelSelector, IBinaryOutputCombiner>,
IModelCombiner
{
Expand All @@ -47,20 +49,15 @@ public sealed class Arguments : ArgumentsBase

// REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors;
public IComponentFactory<TScalarTrainer>[] BasePredictors;

internal override IComponentFactory<ITrainer<TScalarPredictor>>[] GetPredictorFactories() => BasePredictors;
internal override IComponentFactory<TScalarTrainer>[] GetPredictorFactories() => BasePredictors;

public Arguments()
{
BasePredictors = new[]
{
ComponentFactoryUtils.CreateFromFunction(
env => {
var trainerEstimator = new LinearSvmTrainer(env);
return TrainerUtils.MapTrainerEstimatorToTrainer<LinearSvmTrainer,
LinearBinaryModelParameters, LinearBinaryModelParameters>(env, trainerEstimator);
})
ComponentFactoryUtils.CreateFromFunction(env => new LinearSvmTrainer(env, LabelColumnName, FeatureColumnName))
};
}
}
Expand All @@ -83,7 +80,7 @@ private EnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind pre

private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

private protected override TScalarPredictor CreatePredictor(List<FeatureSubsetModel<float>> models)
private protected override IPredictor CreatePredictor(List<FeatureSubsetModel<float>> models)
{
if (models.All(m => m.Predictor is TDistPredictor))
return new EnsembleDistributionModelParameters(Host, PredictionKind, CreateModels<TDistPredictor>(models), Combiner);
Expand Down
20 changes: 10 additions & 10 deletions src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ namespace Microsoft.ML.Trainers.Ensemble
{
using Stopwatch = System.Diagnostics.Stopwatch;

internal abstract class EnsembleTrainerBase<TOutput, TPredictor, TSelector, TCombiner> : ITrainer<TPredictor>
where TPredictor : class, IPredictorProducing<TOutput>
internal abstract class EnsembleTrainerBase<TOutput, TSelector, TCombiner> : ITrainer<IPredictor>
where TSelector : class, ISubModelSelector<TOutput>
where TCombiner : class, IOutputCombiner<TOutput>
{
Expand Down Expand Up @@ -51,7 +50,7 @@ public abstract class ArgumentsBase : TrainerInputBaseWithLabel
[TGUI(Label = "Show Sub-Model Metrics")]
public bool ShowMetrics;

internal abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>>[] GetPredictorFactories();
internal abstract IComponentFactory<ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>>[] GetPredictorFactories();
#pragma warning restore CS0649
}

Expand All @@ -62,7 +61,7 @@ public abstract class ArgumentsBase : TrainerInputBaseWithLabel
private protected readonly IHost Host;

/// <summary> Ensemble members </summary>
private protected readonly ITrainer<IPredictorProducing<TOutput>>[] Trainers;
private protected readonly ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>[] Trainers;

private readonly ISubsetSelector _subsetSelector;
private protected ISubModelSelector<TOutput> SubModelSelector;
Expand Down Expand Up @@ -95,7 +94,7 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env,

_subsetSelector = Args.SamplingType.CreateComponent(Host);

Trainers = new ITrainer<IPredictorProducing<TOutput>>[NumModels];
Trainers = new ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>[NumModels];
for (int i = 0; i < Trainers.Length; i++)
Trainers[i] = predictorFactories[i % predictorFactories.Length].CreateComponent(Host);
// We infer normalization and calibration preferences from the trainers. However, even if the internal trainers
Expand All @@ -106,7 +105,7 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env,
}
}

TPredictor ITrainer<TPredictor>.Train(TrainContext context)
IPredictor ITrainer<IPredictor>.Train(TrainContext context)
{
Host.CheckValue(context, nameof(context));

Expand All @@ -117,9 +116,9 @@ TPredictor ITrainer<TPredictor>.Train(TrainContext context)
}

IPredictor ITrainer.Train(TrainContext context)
=> ((ITrainer<TPredictor>)this).Train(context);
=> ((ITrainer<IPredictor>)this).Train(context);

private TPredictor TrainCore(IChannel ch, RoleMappedData data)
private IPredictor TrainCore(IChannel ch, RoleMappedData data)
{
Host.AssertValue(ch);
ch.AssertValue(data);
Expand Down Expand Up @@ -155,8 +154,9 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data)
{
if (EnsureMinimumFeaturesSelected(subset))
{
// REVIEW: How to pass the role mappings to the trainer?
var model = new FeatureSubsetModel<TOutput>(
Trainers[(int)index].Train(subset.Data),
Trainers[(int)index].Fit(subset.Data.Data).Model,
subset.SelectedFeatures,
null);
SubModelSelector.CalculateMetrics(model, _subsetSelector, subset, batch, needMetrics);
Expand Down Expand Up @@ -190,7 +190,7 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data)
return CreatePredictor(models);
}

private protected abstract TPredictor CreatePredictor(List<FeatureSubsetModel<TOutput>> models);
private protected abstract IPredictor CreatePredictor(List<FeatureSubsetModel<TOutput>> models);

private bool EnsureMinimumFeaturesSelected(Subset subset)
{
Expand Down
Loading