Adding Early stopping feature in ImageClassification (WIP) #4237

Merged · 10 commits · Oct 2, 2019
@@ -67,6 +67,7 @@ public static void Example()
epoch: 50,
batchSize: 10,
learningRate: 0.01f,
enableEarlyStopping: true,
@codemzs (Member), Sep 25, 2019:

I would prefer you create a new sample for early stopping. #Resolved

metricsCallback: (metrics) => Console.WriteLine(metrics),
validationSet: testDataset);

9 changes: 9 additions & 0 deletions src/Microsoft.ML.Dnn/DnnCatalog.cs
@@ -89,6 +89,9 @@ public static DnnRetrainEstimator RetrainDnnModel(
/// <param name="epoch">Number of training iterations. Each iteration/epoch refers to one pass over the dataset.</param>
/// <param name="batchSize">The batch size for training.</param>
/// <param name="learningRate">The learning rate for training.</param>
/// <param name="enableEarlyStopping">Whether early stopping technique should be used when accuracy stops improving.</param>
/// <param name="earlyStoppingminDelta">Minimum change in accuracy to qualify as improvement.</param>
/// <param name="earlyStoppingPatience">Number of epochs to wait after no improvement is observed before early stopping.</param>
@codemzs (Member), Sep 25, 2019:

Make this option a class. #Resolved

/// <param name="metricsCallback">Callback for reporting model statistics during training phase.</param>
/// <param name="statisticFrequency">Indicates the frequency of epochs at which to report model statistics during training phase.</param>
/// <param name="framework">Indicates the choice of DNN training framework. Currently only tensorflow is supported.</param>
@@ -113,6 +116,9 @@ public static ImageClassificationEstimator ImageClassification(
int epoch = 100,
int batchSize = 10,
float learningRate = 0.01f,
bool enableEarlyStopping = true,
@eerhardt (Member), Sep 23, 2019:

Do we really need an enableEarlyStopping parameter? What if we instead used nullable earlyStoppingminDelta and earlyStoppingPatience parameters, whose default value is null? If the user doesn't supply those values, then early stopping isn't enabled. #Resolved
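The nullable-parameter design suggested here can be sketched as follows. The helper class and its name are hypothetical illustrations, not ML.NET code, and one possible reading of the suggestion requires both values before early stopping activates:

```csharp
using System;

// Hypothetical sketch of the nullable-parameter API: no boolean flag is
// needed because supplying the tuning values is itself the opt-in.
public static class EarlyStoppingDefaults
{
    // Early stopping is considered enabled only when the caller supplies
    // both the minimum delta and the patience.
    public static bool IsEnabled(float? earlyStoppingMinDelta, int? earlyStoppingPatience)
        => earlyStoppingMinDelta.HasValue && earlyStoppingPatience.HasValue;
}
```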

@ashbhandare (Contributor, Author):

According to @Zeeshan Siddiqui, we do not want the user to have to set any of these parameters, and want the default values to work well in most cases. If the default is set to null, the users would have to appropriately set these values to make use of this feature.

In reply to: 327289504

@eerhardt (Member), Sep 27, 2019:

I'm not convinced this is the best API to enable this. What happens if we want to enable a different stopping criteria in the future?

It feels like we should consider a different API to enable this. Check out the EarlyStoppingCriteria in FastTree. That seems like more of an extensible/future-proof API.

/// <summary>
/// Early stopping rule used to terminate training process once meeting a specified criterion.
/// Used for setting <see cref="EarlyStoppingRule"/> <see cref="BoostedTreeOptions.EarlyStoppingRule"/>.
/// </summary>
public abstract class EarlyStoppingRuleBase
#Resolved

@codemzs (Member), Sep 29, 2019:

@eerhardt @ashbhandare We spoke about this on Friday and I suggested we have a class object that defines the early stopping criteria; this class should extend an interface that defines bool ShouldStop(...). The API parameter should be a reference to this interface, EarlyStopping, and if it is set to null then we don't apply early stopping, but by default it can be set to XYZEarlyStopping ... #Resolved
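The interface-based design described in this thread might look like the following sketch. All names here are assumptions for illustration, not the ML.NET API; the trainer would hold a possibly-null reference and query it once per epoch:

```csharp
// Hypothetical criterion interface: a null reference means no early stopping;
// new stopping rules can be added later without changing the trainer signature.
public interface IEarlyStoppingCriterion
{
    // Called once per epoch with the monitored metric; returns true to stop.
    bool ShouldStop(float currentMetric);
}

// One concrete rule: stop after `patience` epochs without an improvement
// of at least `minDelta` in the monitored (accuracy-like) metric.
public sealed class PatienceEarlyStopping : IEarlyStoppingCriterion
{
    private readonly float _minDelta;
    private readonly int _patience;
    private float _best = float.MinValue;
    private int _wait;

    public PatienceEarlyStopping(float minDelta = 0.01f, int patience = 20)
    {
        _minDelta = minDelta;
        _patience = patience;
    }

    public bool ShouldStop(float currentMetric)
    {
        if (currentMetric - _minDelta > _best)
        {
            _best = currentMetric;   // improvement observed, reset the counter
            _wait = 0;
            return false;
        }
        return ++_wait >= _patience; // stop once patience is exhausted
    }
}
```

In the training loop this would replace the enableEarlyStopping flag entirely: the loop simply checks `criterion?.ShouldStop(accuracy) == true` each epoch.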

float earlyStoppingminDelta = 0.01f,
int earlyStoppingPatience = 20,
ImageClassificationMetricsCallback metricsCallback = null,
int statisticFrequency = 1,
DnnFramework framework = DnnFramework.Tensorflow,
@@ -136,6 +142,9 @@ public static ImageClassificationEstimator ImageClassification(
Epoch = epoch,
LearningRate = learningRate,
BatchSize = batchSize,
EnableEarlyStopping = enableEarlyStopping,
MinDelta = earlyStoppingminDelta,
Patience = earlyStoppingPatience,
@codemzs (Member), Sep 25, 2019:

Since you are taking this technique from https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/python/keras/callbacks.py#L1143-L1260, were you also planning to add the "modes", i.e. min, max, auto?

mode: One of {"auto", "min", "max"}. In min mode, training will stop when the quantity monitored has stopped decreasing; in max mode it will stop when the quantity monitored has stopped increasing; in auto mode, the direction is automatically inferred from the name of the monitored quantity.

I think we should add this. #Resolved
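The Keras-style mode handling quoted above could be carried over to C# roughly as follows. This is a sketch of the idea with hypothetical names, not ML.NET code:

```csharp
using System;

// Which direction counts as improvement for the monitored metric.
public enum MonitorMode { Auto, Min, Max }

public static class MonitorDirection
{
    // Returns +1 when improvement means the metric should increase,
    // -1 when it should decrease. Auto infers the direction from the
    // metric name, mirroring the Keras heuristic ('acc' => increase).
    public static int Resolve(MonitorMode mode, string metricName)
    {
        switch (mode)
        {
            case MonitorMode.Min: return -1;
            case MonitorMode.Max: return +1;
            default:
                return metricName.ToLowerInvariant().Contains("acc") ? +1 : -1;
        }
    }

    // Direction-aware improvement test, analogous to Keras' monitor_op:
    // the metric must beat the best value by more than minDelta.
    public static bool Improved(int direction, float current, float best, float minDelta)
        => direction > 0 ? current - minDelta > best
                         : current + minDelta < best;
}
```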

ScoreColumnName = scoreColumnName,
PredictedLabelColumnName = predictedLabelColumnName,
FinalModelPrefix = finalModelPrefix,
44 changes: 43 additions & 1 deletion src/Microsoft.ML.Dnn/ImageClassificationTransform.cs
@@ -338,7 +338,13 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,

ImageClassificationMetrics metrics = new ImageClassificationMetrics();
metrics.Train = new TrainMetrics();
for (int epoch = 0; epoch < epochs; epoch += 1)
//Early stopping variables
bool earlyStop = false;
int wait = 0;
var history = new TrainMetrics();
history.Accuracy = 0;
@codemzs (Member), Sep 25, 2019:

Why create a new class when you just want a variable to store accuracy? Just create a float variable "lastSeenAccuracy". #Resolved

@ashbhandare (Contributor, Author):

addressed here: #4237 (comment)

In reply to: 328229850


@bpstark (Contributor), Sep 23, 2019:

Why are you using a TrainMetrics object when you only care about the history of the accuracy? You can simply use a float here instead. #Resolved

@ashbhandare (Contributor, Author), Sep 26, 2019:

This was to leave open the possibility of using other train metrics as the criteria for early stopping. I will refactor the code and this will change. #Resolved

for (int epoch = 0; epoch < epochs & !earlyStop; epoch += 1)
{
metrics.Train.Accuracy = 0;
metrics.Train.CrossEntropy = 0;
@@ -445,6 +451,24 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
statisticsCallback(metrics);
}
}
// early stopping check
if (options.EnableEarlyStopping)
{
if (metrics.Train.Accuracy - options.MinDelta > history.Accuracy)
{
history.Accuracy = metrics.Train.Accuracy;
wait = 0;
}
@codemzs (Member), Sep 25, 2019:

This is incorrect. It needs to be absolute change as documented here: https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/python/keras/callbacks.py#L1143-L1260

min_delta: Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement.

Looking at the code:

    if mode == 'min':
        self.monitor_op = np.less
    elif mode == 'max':
        self.monitor_op = np.greater
    else:
        if 'acc' in self.monitor:
            self.monitor_op = np.greater
        else:
            self.monitor_op = np.less

    if self.monitor_op == np.greater:
        self.min_delta *= 1
    else:
        self.min_delta *= -1

The last 4 lines above change the sign of the delta, and that takes care of the absolute difference in the function below:

    def on_epoch_end(self, epoch, logs=None):
        current = self.get_monitor_value(logs)
        if current is None:
            return
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                if self.restore_best_weights:
                    if self.verbose > 0:
                        print('Restoring model weights from the end of the best epoch.')
                    self.model.set_weights(self.best_weights)

#Resolved

@ashbhandare (Contributor, Author):

As we discussed offline, the change in sign of the min_delta doesn't handle taking the absolute value. Even though the comment says they take the absolute value of the change, it is not implemented that way in the code. However, we want to consider the absolute value, and I will make that change.

In reply to: 328248944
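The absolute-value behavior the author says they will adopt can be sketched as follows (a hypothetical helper, assuming an accuracy-like metric that should increase): whatever sign the caller passes for minDelta, the improvement threshold uses its magnitude.

```csharp
using System;

// Hypothetical sketch: improvement counts only when the metric rises by
// more than |minDelta|, regardless of the sign the caller supplied.
public static class AbsoluteDelta
{
    public static bool ImprovedForAccuracy(float current, float best, float minDelta)
        => current - Math.Abs(minDelta) > best;
}
```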

else
{
wait += 1;
if (wait >= options.Patience)
{
Console.WriteLine("*** Early Stopping at epoch " + epoch.ToString());
@eerhardt (Member), Sep 23, 2019:

Please don't Console.Write inside of library code. #Resolved

@codemzs (Member), Sep 25, 2019:

Console.WriteLine("*** Early Stopping at epoch " + epoch.ToString());

Use message channels for logging. #Resolved
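The suggestion can be illustrated with a self-contained sketch (this is not the actual ML.NET channel API): the training loop writes through an injected delegate, so library code never touches Console and the host decides where messages go.

```csharp
using System;
using System.Collections.Generic;

// Sketch of an early-stopping loop that reports through a message channel
// (modeled here as a plain delegate) instead of Console.WriteLine.
public sealed class TrainingLoop
{
    private readonly Action<string> _log;

    // The channel defaults to a no-op, never to Console.
    public TrainingLoop(Action<string> log = null)
    {
        _log = log ?? new Action<string>(_ => { });
    }

    // Runs over pre-computed per-epoch accuracies and returns the epoch
    // at which training stopped (early or at the end).
    public int Run(IReadOnlyList<float> accuracyPerEpoch, float minDelta, int patience)
    {
        float best = float.MinValue;
        int wait = 0;
        for (int epoch = 0; epoch < accuracyPerEpoch.Count; epoch++)
        {
            if (accuracyPerEpoch[epoch] - minDelta > best)
            {
                best = accuracyPerEpoch[epoch];
                wait = 0;
            }
            else if (++wait >= patience)
            {
                _log($"Early stopping at epoch {epoch}"); // replaces Console.WriteLine
                return epoch;
            }
        }
        return accuracyPerEpoch.Count - 1;
    }
}
```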

earlyStop = true;
}
}
@codemzs (Member), Sep 25, 2019:

Please do not put up a PR without a unit test; if you do, please mark it as a Draft PR or WIP. #Resolved

}
@codemzs (Member), Sep 25, 2019:

Please add comments here that document this technique, and also add relevant links. #Resolved

@codemzs (Member), Oct 1, 2019:

Why not just break out? It will save you a variable. #Resolved

@ashbhandare (Contributor, Author):

done

In reply to: 330316698

}

trainSaver.save(_session, _checkpointPath);
@@ -1065,6 +1089,24 @@ internal sealed class Options : TransformInputBase
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate to use during optimization.", SortOrder = 12)]
public float LearningRate = 0.01f;

/// <summary>
/// Whether early stopping technique should be used when accuracy stops improving.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether early stopping technique should be used when accuracy stops improving.", SortOrder = 15)]
public bool EnableEarlyStopping = true;

/// <summary>
/// Minimum change in accuracy to qualify as improvement.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Minimum change in accuracy to qualify as improvement.", SortOrder = 15)]
public float MinDelta = 0.0f;

/// <summary>
/// Number of epochs to wait after no improvement is observed before early stopping.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of epochs to wait after no improvement is observed before early stopping.", SortOrder = 15)]
public int Patience = 20;

/// <summary>
/// Specifies the model architecture to be used in the case of image classification training using transfer learning.
/// </summary>