Skip to content

Polish early stop rules in fast tree #2851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/Microsoft.ML.FastTree/BoostingFastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ private protected override void CheckOptions(IChannel ch)
if (FastTreeTrainerOptions.EnablePruning && !HasValidSet)
throw ch.Except("Cannot perform pruning (pruning) without a validation set (valid).");

if (FastTreeTrainerOptions.EarlyStoppingRule != null && !HasValidSet)
bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null;
if (doEarlyStop && !HasValidSet)
throw ch.Except("Cannot perform early stopping without a validation set (valid).");

if (FastTreeTrainerOptions.UseTolerantPruning && (!FastTreeTrainerOptions.EnablePruning || !HasValidSet))
Expand Down Expand Up @@ -113,9 +114,9 @@ private protected override IGradientAdjuster MakeGradientWrapper(IChannel ch)
return new BestStepRegressionGradientWrapper();
}

private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStoppingRule, ref int bestIteration)
private protected override bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStoppingRule, ref int bestIteration)
{
if (FastTreeTrainerOptions.EarlyStoppingRule == null)
if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null)
return false;

ch.AssertValue(ValidTest);
Expand All @@ -128,13 +129,16 @@ private protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriter
var trainingResult = TrainTest.ComputeTests().First();
ch.Assert(trainingResult.FinalValue >= 0);

// Create early stopping rule.
// Create early stopping rule if it's null.
if (earlyStoppingRule == null)
{
earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRule.CreateComponent(Host, lowerIsBetter);
ch.Assert(earlyStoppingRule != null);
if (FastTreeTrainerOptions.EarlyStoppingRuleFactory != null)
earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRuleFactory.CreateComponent(Host, lowerIsBetter);
}

// Early stopping rule cannot be null!
ch.Assert(earlyStoppingRule != null);

bool isBestCandidate;
bool shouldStop = earlyStoppingRule.CheckScore((float)validationResult.FinalValue,
(float)trainingResult.FinalValue, out isBestCandidate);
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ private protected void TrainCore(IChannel ch)
}
}

private protected virtual bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStopping, ref int bestIteration)
private protected virtual bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStopping, ref int bestIteration)
{
bestIteration = Ensemble.NumTrees;
return false;
Expand Down Expand Up @@ -650,7 +650,7 @@ private protected virtual void Train(IChannel ch)
#endif
#endif

IEarlyStoppingCriterion earlyStoppingRule = null;
EarlyStoppingRuleBase earlyStoppingRule = null;
int bestIteration = 0;
int emptyTrees = 0;
using (var pch = Host.StartProgressChannel("FastTree training"))
Expand Down
24 changes: 22 additions & 2 deletions src/Microsoft.ML.FastTree/FastTreeArguments.cs
Original file line number Diff line number Diff line change
Expand Up @@ -621,9 +621,29 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc
/// <summary>
/// Early stopping rule. (Validation set (/valid) is required).
/// </summary>
[Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", ShortName = "esr", NullName = "<Disable>")]
[BestFriend]
[Argument(ArgumentType.Multiple, HelpText = "Early stopping rule. (Validation set (/valid) is required.)", Name = "EarlyStoppingRule", ShortName = "esr", NullName = "<Disable>")]
[TGUI(Label = "Early Stopping Rule", Description = "Early stopping rule. (Validation set (/valid) is required.)")]
public IEarlyStoppingCriterionFactory EarlyStoppingRule;
internal IEarlyStoppingCriterionFactory EarlyStoppingRuleFactory;

/// <summary>
/// The underlying state of <see cref="EarlyStoppingRuleFactory"/> and <see cref="EarlyStoppingRule"/>.
/// </summary>
private EarlyStoppingRuleBase _earlyStoppingRuleBase;

/// <summary>
/// Early stopping rule used to terminate the training process once a specified criterion is met. Possible choices are
/// <see cref="EarlyStoppingRuleBase"/>'s implementations such as <see cref="TolerantEarlyStoppingRule"/> and <see cref="GeneralityLossRule"/>.
/// </summary>
/// <value>
/// The rule most recently assigned through this property. Assigning a rule also rebuilds
/// <see cref="EarlyStoppingRuleFactory"/> from it; assigning <see langword="null"/> clears the factory.
/// </value>
public EarlyStoppingRuleBase EarlyStoppingRule
{
    get { return _earlyStoppingRuleBase; }
    set
    {
        _earlyStoppingRuleBase = value;
        // Keep the factory in sync with the rule. A null rule must produce a null factory
        // instead of a NullReferenceException from calling BuildFactory() on null.
        EarlyStoppingRuleFactory = value?.BuildFactory();
    }
}

/// <summary>
/// Early stopping metrics. (For regression, 1: L1, 2: L2; for ranking, 1: NDCG@1, 3: NDCG@3).
Expand Down
8 changes: 6 additions & 2 deletions src/Microsoft.ML.FastTree/FastTreeRanking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,12 @@ private protected override void CheckOptions(IChannel ch)
Dataset.DatasetSkeleton.LabelGainMap = gains;
}

ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
"earlyStoppingMetrics should be 1 or 3.");
bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
FastTreeTrainerOptions.EnablePruning;

if (doEarlyStop)
ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3,
nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 or 3.");

base.CheckOptions(ch);
}
Expand Down
8 changes: 6 additions & 2 deletions src/Microsoft.ML.FastTree/FastTreeRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,12 @@ private protected override void CheckOptions(IChannel ch)

base.CheckOptions(ch);

ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
"earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
FastTreeTrainerOptions.EnablePruning;

if (doEarlyStop)
ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2,
nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
}

private static SchemaShape.Column MakeLabelColumn(string labelColumn)
Expand Down
10 changes: 7 additions & 3 deletions src/Microsoft.ML.FastTree/FastTreeTweedie.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,16 @@ private protected override void CheckOptions(IChannel ch)
// REVIEW: In order to properly support early stopping, the early stopping metric should be a subcomponent, not just
// a simple integer, because the metric that we might want is parameterized by this floating point "index" parameter. For now
// we just leave the existing regression checks, though with a warning.

if (FastTreeTrainerOptions.EarlyStoppingMetrics > 0)
ch.Warning("For Tweedie regression, early stopping does not yet use the Tweedie distribution.");

ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
"earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
bool doEarlyStop = FastTreeTrainerOptions.EarlyStoppingRuleFactory != null ||
FastTreeTrainerOptions.EnablePruning;

// Do not remove this check; see the REVIEW comment above for why the existing regression checks are kept.
if (doEarlyStop)
ch.CheckUserArg(FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 2,
nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 (L1-norm) or 2 (L2-norm).");
}

private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
Expand Down
Loading