Skip to content

Commit facd1b0

Browse files
yaeldMS authored and Dmitry-A committed
Change ensembles trainer to work with ITrainerEstimators instead of ITrainers (dotnet#3796)
* Change ensembles trainer to work with ITrainerEstimators instead of ITrainers
* Use NumberParseOptions.UseSingle when comparing baselines
* Decrease number of digits of precision
* Decrease number of digits of precision in some tests
* Disable some tests on netcoreapp3.0
1 parent 114eab3 commit facd1b0

File tree

60 files changed

+11184
-98
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

60 files changed

+11184
-98
lines changed

src/Microsoft.ML.Ensemble/OutputCombiners/BaseStacking.cs

Lines changed: 64 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -22,10 +22,10 @@ public abstract class ArgumentsBase
2222
[TGUI(Label = "Validation Dataset Proportion")]
2323
public Single ValidationDatasetProportion = 0.3f;
2424

25-
internal abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> GetPredictorFactory();
25+
internal abstract IComponentFactory<ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>> GetPredictorFactory();
2626
}
2727

28-
private protected readonly IComponentFactory<ITrainer<IPredictorProducing<TOutput>>> BasePredictorType;
28+
private protected readonly IComponentFactory<ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>> BasePredictorType;
2929
private protected readonly IHost Host;
3030
private protected IPredictorProducing<TOutput> Meta;
3131

@@ -140,54 +140,75 @@ public void Train(List<FeatureSubsetModel<TOutput>> models, RoleMappedData data,
140140
maps[i] = m.GetMapper<VBuffer<Single>, TOutput>();
141141
}
142142

143-
// REVIEW: Should implement this better....
144-
var labels = new Single[100];
145-
var features = new VBuffer<Single>[100];
146-
int count = 0;
147-
// REVIEW: Should this include bad values or filter them?
148-
using (var cursor = new FloatLabelCursor(data, CursOpt.AllFeatures | CursOpt.AllLabels))
143+
var view = CreateDataView(host, ch, data, maps, models);
144+
var trainer = BasePredictorType.CreateComponent(host);
145+
if (trainer.Info.NeedNormalization)
146+
ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
147+
Meta = trainer.Fit(view).Model;
148+
CheckMeta();
149+
}
150+
}
151+
152+
// Builds the data view used to fit the stacking (meta) trainer by dispatching on the
// raw kind of the label column's type. Each supported kind forwards to the generic
// CreateDataView<T> overload together with a converter from the float label read by
// the cursor to the label's own representation:
//   BL -> bool  (label > 0)
//   R4 -> float (unchanged)
//   U4 -> uint  (only when the label type is a KeyDataViewType)
// Any other label kind raises an exception through the channel.
private IDataView CreateDataView(IHostEnvironment env, IChannel ch, RoleMappedData data, ValueMapper<VBuffer<Single>,
153+
TOutput>[] maps, List<FeatureSubsetModel<TOutput>> models)
154+
{
155+
switch (data.Schema.Label.Value.Type.GetRawKind())
156+
{
157+
case InternalDataKind.BL:
158+
return CreateDataView<bool>(env, ch, data, maps, models, x => x > 0);
159+
case InternalDataKind.R4:
160+
return CreateDataView<float>(env, ch, data, maps, models, x => x);
161+
case InternalDataKind.U4:
162+
ch.Check(data.Schema.Label.Value.Type is KeyDataViewType);
163+
// NaN becomes 0 and every other value is shifted up by one before the uint cast.
// NOTE(review): presumably 0 is the "missing" key and the cursor yields zero-based
// float labels for key columns - confirm against FloatLabelCursor.
return CreateDataView(env, ch, data, maps, models, x => float.IsNaN(x) ? 0 : (uint)(x + 1));
164+
default:
165+
throw ch.Except("Unsupported label type");
166+
}
167+
}
168+
169+
private IDataView CreateDataView<T>(IHostEnvironment env, IChannel ch, RoleMappedData data, ValueMapper<VBuffer<Single>, TOutput>[] maps,
170+
List<FeatureSubsetModel<TOutput>> models, Func<float, T> labelConvert)
171+
{
172+
// REVIEW: Should implement this better....
173+
var labels = new T[100];
174+
var features = new VBuffer<Single>[100];
175+
int count = 0;
176+
// REVIEW: Should this include bad values or filter them?
177+
using (var cursor = new FloatLabelCursor(data, CursOpt.AllFeatures | CursOpt.AllLabels))
178+
{
179+
TOutput[] predictions = new TOutput[maps.Length];
180+
var vBuffers = new VBuffer<Single>[maps.Length];
181+
while (cursor.MoveNext())
149182
{
150-
TOutput[] predictions = new TOutput[maps.Length];
151-
var vBuffers = new VBuffer<Single>[maps.Length];
152-
while (cursor.MoveNext())
183+
Parallel.For(0, maps.Length, i =>
153184
{
154-
Parallel.For(0, maps.Length, i =>
185+
var model = models[i];
186+
if (model.SelectedFeatures != null)
155187
{
156-
var model = models[i];
157-
if (model.SelectedFeatures != null)
158-
{
159-
EnsembleUtils.SelectFeatures(in cursor.Features, model.SelectedFeatures, model.Cardinality, ref vBuffers[i]);
160-
maps[i](in vBuffers[i], ref predictions[i]);
161-
}
162-
else
163-
maps[i](in cursor.Features, ref predictions[i]);
164-
});
165-
166-
Utils.EnsureSize(ref labels, count + 1);
167-
Utils.EnsureSize(ref features, count + 1);
168-
labels[count] = cursor.Label;
169-
FillFeatureBuffer(predictions, ref features[count]);
170-
count++;
171-
}
188+
EnsembleUtils.SelectFeatures(in cursor.Features, model.SelectedFeatures, model.Cardinality, ref vBuffers[i]);
189+
maps[i](in vBuffers[i], ref predictions[i]);
190+
}
191+
else
192+
maps[i](in cursor.Features, ref predictions[i]);
193+
});
194+
195+
Utils.EnsureSize(ref labels, count + 1);
196+
Utils.EnsureSize(ref features, count + 1);
197+
labels[count] = labelConvert(cursor.Label);
198+
FillFeatureBuffer(predictions, ref features[count]);
199+
count++;
172200
}
201+
}
173202

174-
ch.Info("The number of instances used for stacking trainer is {0}", count);
175-
176-
var bldr = new ArrayDataViewBuilder(host);
177-
Array.Resize(ref labels, count);
178-
Array.Resize(ref features, count);
179-
bldr.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, labels);
180-
bldr.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single, features);
203+
ch.Info("The number of instances used for stacking trainer is {0}", count);
181204

182-
var view = bldr.GetDataView();
183-
var rmd = new RoleMappedData(view, DefaultColumnNames.Label, DefaultColumnNames.Features);
205+
var bldr = new ArrayDataViewBuilder(env);
206+
Array.Resize(ref labels, count);
207+
Array.Resize(ref features, count);
208+
bldr.AddColumn(DefaultColumnNames.Label, data.Schema.Label.Value.Type as PrimitiveDataViewType, labels);
209+
bldr.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single, features);
184210

185-
var trainer = BasePredictorType.CreateComponent(host);
186-
if (trainer.Info.NeedNormalization)
187-
ch.Warning("The trainer specified for stacking wants normalization, but we do not currently allow this.");
188-
Meta = trainer.Train(rmd);
189-
CheckMeta();
190-
}
211+
return bldr.GetDataView();
191212
}
192213
}
193214
}

src/Microsoft.ML.Ensemble/OutputCombiners/MultiStacking.cs

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,8 @@
1919

2020
namespace Microsoft.ML.Trainers.Ensemble
2121
{
22-
using TVectorPredictor = IPredictorProducing<VBuffer<Single>>;
22+
using TVectorTrainer = ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<VBuffer<float>>>, IPredictorProducing<VBuffer<float>>>;
23+
2324
internal sealed class MultiStacking : BaseStacking<VBuffer<Single>>, IMulticlassOutputCombiner
2425
{
2526
public const string LoadName = "MultiStacking";
@@ -44,9 +45,9 @@ public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerF
4445
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
4546
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureMulticlassClassifierTrainer))]
4647
[TGUI(Label = "Base predictor")]
47-
public IComponentFactory<ITrainer<TVectorPredictor>> BasePredictorType;
48+
public IComponentFactory<TVectorTrainer> BasePredictorType;
4849

49-
internal override IComponentFactory<ITrainer<TVectorPredictor>> GetPredictorFactory() => BasePredictorType;
50+
internal override IComponentFactory<TVectorTrainer> GetPredictorFactory() => BasePredictorType;
5051

5152
public IMulticlassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiStacking(env, this);
5253
}

src/Microsoft.ML.Ensemble/OutputCombiners/RegressionStacking.cs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
namespace Microsoft.ML.Trainers.Ensemble
1919
{
20-
using TScalarPredictor = IPredictorProducing<Single>;
20+
using TScalarTrainer = ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
2121

2222
internal sealed class RegressionStacking : BaseScalarStacking, IRegressionOutputCombiner
2323
{
@@ -43,9 +43,9 @@ public sealed class Arguments : ArgumentsBase, ISupportRegressionOutputCombinerF
4343
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
4444
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureRegressorTrainer))]
4545
[TGUI(Label = "Base predictor")]
46-
public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType;
46+
public IComponentFactory<TScalarTrainer> BasePredictorType;
4747

48-
internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType;
48+
internal override IComponentFactory<TScalarTrainer> GetPredictorFactory() => BasePredictorType;
4949

5050
public IRegressionOutputCombiner CreateComponent(IHostEnvironment env) => new RegressionStacking(env, this);
5151
}

src/Microsoft.ML.Ensemble/OutputCombiners/Stacking.cs

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,8 @@
1515

1616
namespace Microsoft.ML.Trainers.Ensemble
1717
{
18-
using TScalarPredictor = IPredictorProducing<Single>;
18+
using TScalarTrainer = ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
19+
1920
internal sealed class Stacking : BaseScalarStacking, IBinaryOutputCombiner
2021
{
2122
public const string UserName = "Stacking";
@@ -41,9 +42,9 @@ public sealed class Arguments : ArgumentsBase, ISupportBinaryOutputCombinerFacto
4142
[Argument(ArgumentType.Multiple, HelpText = "Base predictor for meta learning", ShortName = "bp", SortOrder = 50,
4243
Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
4344
[TGUI(Label = "Base predictor")]
44-
public IComponentFactory<ITrainer<TScalarPredictor>> BasePredictorType;
45+
public IComponentFactory<TScalarTrainer> BasePredictorType;
4546

46-
internal override IComponentFactory<ITrainer<TScalarPredictor>> GetPredictorFactory() => BasePredictorType;
47+
internal override IComponentFactory<TScalarTrainer> GetPredictorFactory() => BasePredictorType;
4748

4849
public IBinaryOutputCombiner CreateComponent(IHostEnvironment env) => new Stacking(env, this);
4950
}

src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,12 @@ namespace Microsoft.ML.Trainers.Ensemble
2323
{
2424
using TDistPredictor = IDistPredictorProducing<Single, Single>;
2525
using TScalarPredictor = IPredictorProducing<Single>;
26+
using TScalarTrainer = ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<float>>, IPredictorProducing<float>>;
27+
2628
/// <summary>
2729
/// A generic ensemble trainer for binary classification.
2830
/// </summary>
29-
internal sealed class EnsembleTrainer : EnsembleTrainerBase<Single, TScalarPredictor,
31+
internal sealed class EnsembleTrainer : EnsembleTrainerBase<Single,
3032
IBinarySubModelSelector, IBinaryOutputCombiner>,
3133
IModelCombiner
3234
{
@@ -47,20 +49,15 @@ public sealed class Arguments : ArgumentsBase
4749

4850
// REVIEW: If we make this public again it should be an *estimator* of this type of predictor, rather than the (deprecated) ITrainer.
4951
[Argument(ArgumentType.Multiple, HelpText = "Base predictor type", ShortName = "bp,basePredictorTypes", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
50-
public IComponentFactory<ITrainer<TScalarPredictor>>[] BasePredictors;
52+
public IComponentFactory<TScalarTrainer>[] BasePredictors;
5153

52-
internal override IComponentFactory<ITrainer<TScalarPredictor>>[] GetPredictorFactories() => BasePredictors;
54+
internal override IComponentFactory<TScalarTrainer>[] GetPredictorFactories() => BasePredictors;
5355

5456
public Arguments()
5557
{
5658
BasePredictors = new[]
5759
{
58-
ComponentFactoryUtils.CreateFromFunction(
59-
env => {
60-
var trainerEstimator = new LinearSvmTrainer(env);
61-
return TrainerUtils.MapTrainerEstimatorToTrainer<LinearSvmTrainer,
62-
LinearBinaryModelParameters, LinearBinaryModelParameters>(env, trainerEstimator);
63-
})
60+
ComponentFactoryUtils.CreateFromFunction(env => new LinearSvmTrainer(env, LabelColumnName, FeatureColumnName))
6461
};
6562
}
6663
}
@@ -83,7 +80,7 @@ private EnsembleTrainer(IHostEnvironment env, Arguments args, PredictionKind pre
8380

8481
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
8582

86-
private protected override TScalarPredictor CreatePredictor(List<FeatureSubsetModel<float>> models)
83+
private protected override IPredictor CreatePredictor(List<FeatureSubsetModel<float>> models)
8784
{
8885
if (models.All(m => m.Predictor is TDistPredictor))
8986
return new EnsembleDistributionModelParameters(Host, PredictionKind, CreateModels<TDistPredictor>(models), Combiner);

src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,7 @@ namespace Microsoft.ML.Trainers.Ensemble
1717
{
1818
using Stopwatch = System.Diagnostics.Stopwatch;
1919

20-
internal abstract class EnsembleTrainerBase<TOutput, TPredictor, TSelector, TCombiner> : ITrainer<TPredictor>
21-
where TPredictor : class, IPredictorProducing<TOutput>
20+
internal abstract class EnsembleTrainerBase<TOutput, TSelector, TCombiner> : ITrainer<IPredictor>
2221
where TSelector : class, ISubModelSelector<TOutput>
2322
where TCombiner : class, IOutputCombiner<TOutput>
2423
{
@@ -51,7 +50,7 @@ public abstract class ArgumentsBase : TrainerInputBaseWithLabel
5150
[TGUI(Label = "Show Sub-Model Metrics")]
5251
public bool ShowMetrics;
5352

54-
internal abstract IComponentFactory<ITrainer<IPredictorProducing<TOutput>>>[] GetPredictorFactories();
53+
internal abstract IComponentFactory<ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>>[] GetPredictorFactories();
5554
#pragma warning restore CS0649
5655
}
5756

@@ -62,7 +61,7 @@ public abstract class ArgumentsBase : TrainerInputBaseWithLabel
6261
private protected readonly IHost Host;
6362

6463
/// <summary> Ensemble members </summary>
65-
private protected readonly ITrainer<IPredictorProducing<TOutput>>[] Trainers;
64+
private protected readonly ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>[] Trainers;
6665

6766
private readonly ISubsetSelector _subsetSelector;
6867
private protected ISubModelSelector<TOutput> SubModelSelector;
@@ -95,7 +94,7 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env,
9594

9695
_subsetSelector = Args.SamplingType.CreateComponent(Host);
9796

98-
Trainers = new ITrainer<IPredictorProducing<TOutput>>[NumModels];
97+
Trainers = new ITrainerEstimator<ISingleFeaturePredictionTransformer<IPredictorProducing<TOutput>>, IPredictorProducing<TOutput>>[NumModels];
9998
for (int i = 0; i < Trainers.Length; i++)
10099
Trainers[i] = predictorFactories[i % predictorFactories.Length].CreateComponent(Host);
101100
// We infer normalization and calibration preferences from the trainers. However, even if the internal trainers
@@ -106,7 +105,7 @@ private protected EnsembleTrainerBase(ArgumentsBase args, IHostEnvironment env,
106105
}
107106
}
108107

109-
TPredictor ITrainer<TPredictor>.Train(TrainContext context)
108+
IPredictor ITrainer<IPredictor>.Train(TrainContext context)
110109
{
111110
Host.CheckValue(context, nameof(context));
112111

@@ -117,9 +116,9 @@ TPredictor ITrainer<TPredictor>.Train(TrainContext context)
117116
}
118117

119118
IPredictor ITrainer.Train(TrainContext context)
120-
=> ((ITrainer<TPredictor>)this).Train(context);
119+
=> ((ITrainer<IPredictor>)this).Train(context);
121120

122-
private TPredictor TrainCore(IChannel ch, RoleMappedData data)
121+
private IPredictor TrainCore(IChannel ch, RoleMappedData data)
123122
{
124123
Host.AssertValue(ch);
125124
ch.AssertValue(data);
@@ -155,8 +154,9 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data)
155154
{
156155
if (EnsureMinimumFeaturesSelected(subset))
157156
{
157+
// REVIEW: How to pass the role mappings to the trainer?
158158
var model = new FeatureSubsetModel<TOutput>(
159-
Trainers[(int)index].Train(subset.Data),
159+
Trainers[(int)index].Fit(subset.Data.Data).Model,
160160
subset.SelectedFeatures,
161161
null);
162162
SubModelSelector.CalculateMetrics(model, _subsetSelector, subset, batch, needMetrics);
@@ -190,7 +190,7 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data)
190190
return CreatePredictor(models);
191191
}
192192

193-
private protected abstract TPredictor CreatePredictor(List<FeatureSubsetModel<TOutput>> models);
193+
private protected abstract IPredictor CreatePredictor(List<FeatureSubsetModel<TOutput>> models);
194194

195195
private bool EnsureMinimumFeaturesSelected(Subset subset)
196196
{

0 commit comments

Comments
 (0)