Adding Early stopping feature in ImageClassification (WIP) #4237
@@ -89,6 +89,9 @@ public static DnnRetrainEstimator RetrainDnnModel(
/// <param name="epoch">Number of training iterations. Each iteration/epoch refers to one pass over the dataset.</param>
/// <param name="batchSize">The batch size for training.</param>
/// <param name="learningRate">The learning rate for training.</param>
/// <param name="enableEarlyStopping">Whether the early stopping technique should be used when accuracy stops improving.</param>
/// <param name="earlyStoppingminDelta">Minimum change in accuracy to qualify as an improvement.</param>
/// <param name="earlyStoppingPatience">Number of epochs to wait after no improvement is observed before stopping early.</param>

Review comment: Make this option a class. #Resolved

/// <param name="metricsCallback">Callback for reporting model statistics during the training phase.</param>
/// <param name="statisticFrequency">Indicates the frequency, in epochs, at which to report model statistics during the training phase.</param>
/// <param name="framework">Indicates the choice of DNN training framework. Currently only TensorFlow is supported.</param>
@@ -113,6 +116,9 @@ public static ImageClassificationEstimator ImageClassification(
int epoch = 100,
int batchSize = 10,
float learningRate = 0.01f,
bool enableEarlyStopping = true,
Review comment: Do we really need a …

Review comment: According to @Zeeshan Siddiqui, we do not want the user to have to set any of these parameters, and we want the default values to work well in most cases. If the default is set to null, users would have to set these values appropriately to make use of this feature. In reply to: 327289504

Review comment: I'm not convinced this is the best API to enable this. What happens if we want to enable a different stopping criterion in the future? It feels like we should consider a different API to enable this. Check out the EarlyStoppingCriteria in FastTree; that seems like a more extensible, future-proof API (machinelearning/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs, lines 34 to 38 in 2942ca4).

Review comment: @eerhardt @ashbhandare We spoke about this on Friday and I suggested we have a class object that defines the early stopping criteria. This class should extend an interface that defines bool ShouldStop(...). The API parameter should be a reference to this interface, EarlyStopping, and if it is set to null then we don't apply early stopping, but by default it can be set to some concrete XYZEarlyStopping implementation. #Resolved
float earlyStoppingminDelta = 0.01f,
int earlyStoppingPatience = 20,
ImageClassificationMetricsCallback metricsCallback = null,
int statisticFrequency = 1,
DnnFramework framework = DnnFramework.Tensorflow,
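The extensible design suggested in the thread above could look roughly like the sketch below. It is only an illustration of the reviewers' idea, not code from this PR; the names IEarlyStoppingCriteria and PatienceEarlyStoppingCriteria are hypothetical, loosely modeled on FastTree's EarlyStoppingCriteria.

```csharp
// Hypothetical sketch of the interface-based API discussed above; these types do not
// exist in the PR. A null criteria would mean "no early stopping".
public interface IEarlyStoppingCriteria
{
    // Returns true when training should stop, given the latest monitored metric value.
    bool ShouldStop(float currentMetricValue);
}

public sealed class PatienceEarlyStoppingCriteria : IEarlyStoppingCriteria
{
    private readonly float _minDelta;
    private readonly int _patience;
    private float _best = float.MinValue;
    private int _wait;

    public PatienceEarlyStoppingCriteria(float minDelta = 0.01f, int patience = 20)
    {
        _minDelta = minDelta;
        _patience = patience;
    }

    public bool ShouldStop(float currentMetricValue)
    {
        if (currentMetricValue - _minDelta > _best)
        {
            // Enough improvement: remember the new best value and reset the counter.
            _best = currentMetricValue;
            _wait = 0;
            return false;
        }
        // No meaningful improvement: stop once we have waited "patience" epochs.
        return ++_wait >= _patience;
    }
}
```

With something like this, ImageClassification would take a single IEarlyStoppingCriteria parameter defaulting to a patience-based instance, and passing null would disable early stopping, which is essentially what the last comment proposes.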
@@ -136,6 +142,9 @@ public static ImageClassificationEstimator ImageClassification(
Epoch = epoch,
LearningRate = learningRate,
BatchSize = batchSize,
EnableEarlyStopping = enableEarlyStopping,
MinDelta = earlyStoppingminDelta,
Patience = earlyStoppingPatience,
Review comment: Since you are taking this technique from https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/python/keras/callbacks.py#L1143-L1260, were you also planning to add the "modes", i.e. min, max, auto ("mode: One of {'auto', 'min', 'max'}")? I think we should add this. #Resolved

ScoreColumnName = scoreColumnName,
PredictedLabelColumnName = predictedLabelColumnName,
FinalModelPrefix = finalModelPrefix,
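For the mode question above, here is a rough idea of what min/max handling could look like in C#. This is an assumption-level sketch mirroring the Keras behaviour, not part of the PR; the enum and helper names are hypothetical.

```csharp
// Hypothetical sketch of Keras-style modes; not part of the PR. In Min mode an
// improvement means the monitored value decreased (e.g. loss); in Max mode it
// increased (e.g. accuracy). An Auto mode could infer the direction from the
// metric being monitored, as the Keras callback does.
public enum EarlyStoppingMode { Min, Max }

public static class EarlyStoppingModeHelper
{
    public static bool HasImproved(float current, float best, float minDelta, EarlyStoppingMode mode)
    {
        return mode == EarlyStoppingMode.Min
            ? current < best - minDelta   // value must drop by more than minDelta
            : current > best + minDelta;  // value must rise by more than minDelta
    }
}
```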
@@ -338,7 +338,13 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,

ImageClassificationMetrics metrics = new ImageClassificationMetrics();
metrics.Train = new TrainMetrics();
for (int epoch = 0; epoch < epochs; epoch += 1)
// Early stopping variables
bool earlyStop = false;
int wait = 0;
var history = new TrainMetrics();
history.Accuracy = 0;

Review comment: Why create a new class when you just want a variable to store accuracy? Just create a float variable "lastSeenAccuracy". #Resolved

Review comment: Why are you using a TrainMetrics object when you only care about the history of the accuracy? You can simply use a float here instead. #Resolved

Review comment: This was with the potential of possibly using other train metrics as well in the criteria for early stopping. I will refactor the code and this will change. #Resolved

for (int epoch = 0; epoch < epochs & !earlyStop; epoch += 1)
{
metrics.Train.Accuracy = 0;
metrics.Train.CrossEntropy = 0;
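Per the two comments above, the history object could be replaced by a single float. A minimal sketch of that simplification (the variable name is the reviewer's suggestion, not code from the PR):

```csharp
// Sketch of the suggested simplification: only the best accuracy seen so far is needed.
bool earlyStop = false;
int wait = 0;
float lastSeenAccuracy = 0f;  // replaces "var history = new TrainMetrics();"
```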
@@ -445,6 +451,24 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
statisticsCallback(metrics);
}
}
// Early stopping check
if (options.EnableEarlyStopping)
{
if (metrics.Train.Accuracy - options.MinDelta > history.Accuracy)
{
history.Accuracy = metrics.Train.Accuracy;
wait = 0;
}

Review comment: This is incorrect. It needs to be an absolute change, as documented here: https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/python/keras/callbacks.py#L1143-L1260 ("min_delta: Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement."). Looking at the code, the mode == 'min' / mode == 'max' handling changes the sign of the delta, and that takes care of the absolute difference in the on_epoch_end function. #Resolved

Review comment: As we discussed offline, the change in sign of min_delta doesn't handle taking the absolute value. Even though the comment says they take the absolute value of the change, it is not implemented that way in the code. However, we do want to consider the absolute value, and I will make that change. In reply to: 328248944

else
{
wait += 1;
if (wait >= options.Patience)
{
Console.WriteLine("*** Early Stopping at epoch " + epoch.ToString());

Review comment: Please don't …

Review comment: Use message channels for logging. #Resolved

earlyStop = true;
}
}

Review comment: Please do not put up a PR without a unit test; if you do, please mark it as a Draft PR or WIP. #Resolved

}

Review comment: Please add comments here that document this technique and also add relevant links. #Resolved

Review comment: Why not just break out? It will save you a variable. #Resolved

}

trainSaver.save(_session, _checkpointPath);
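Putting the feedback above together (an absolute change for MinDelta, breaking out of the epoch loop instead of keeping an earlyStop flag, and logging through a channel instead of Console.WriteLine), the per-epoch check might end up looking roughly like this. It is a sketch of the loop-body fragment only, under the assumptions that a float lastSeenAccuracy replaces the history object and that ch is an IChannel obtained from the host; the exact logging call may differ.

```csharp
// Sketch only, not the PR's code: early-stopping check reworked per the review.
if (options.EnableEarlyStopping)
{
    // Treat the change as an improvement only if its absolute value exceeds MinDelta.
    if (Math.Abs(metrics.Train.Accuracy - lastSeenAccuracy) > options.MinDelta)
    {
        lastSeenAccuracy = metrics.Train.Accuracy;
        wait = 0;
    }
    else if (++wait >= options.Patience)
    {
        ch.Info("Early stopping at epoch {0}.", epoch);  // message channel instead of Console
        break;  // leave the epoch loop directly; no separate earlyStop flag needed
    }
}
```

One consequence of the absolute-value test is that a large drop in accuracy would also reset the patience counter; that follows from a literal reading of the Keras min_delta documentation, which is the behaviour the thread settles on.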
@@ -1065,6 +1089,24 @@ internal sealed class Options : TransformInputBase
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate to use during optimization.", SortOrder = 12)]
public float LearningRate = 0.01f;

/// <summary>
/// Whether the early stopping technique should be used when accuracy stops improving.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the early stopping technique should be used when accuracy stops improving.", SortOrder = 15)]
public bool EnableEarlyStopping = true;

/// <summary>
/// Minimum change in accuracy to qualify as an improvement.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Minimum change in accuracy to qualify as an improvement.", SortOrder = 15)]
public float MinDelta = 0.0f;

/// <summary>
/// Number of epochs to wait after no improvement is observed before early stopping.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of epochs to wait after no improvement is observed before early stopping.", SortOrder = 15)]
public int Patience = 20;

/// <summary>
/// Specifies the model architecture to be used in the case of image classification training using transfer learning.
/// </summary>
Review comment: I would prefer you create a new sample for early stopping. #Resolved
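A new sample could be as small as the fragment below. This is a hypothetical sketch against the API added in this PR: the mlContext.Model.ImageClassification entry point, the "Image"/"Label" column names and their order, and the omitted image-loading transforms are all assumptions, so the final sample will likely differ.

```csharp
using System;
using Microsoft.ML;

// Hypothetical early-stopping sample fragment; entry point and column names are assumed.
var mlContext = new MLContext(seed: 1);

// ... load the image dataset and apply the usual image-loading transforms here ...

var trainer = mlContext.Model.ImageClassification(
    "Image", "Label",                 // feature and label columns (assumed names/order)
    epoch: 100,
    batchSize: 10,
    learningRate: 0.01f,
    enableEarlyStopping: true,        // parameters added in this PR
    earlyStoppingminDelta: 0.01f,     // minimum accuracy change that counts as improvement
    earlyStoppingPatience: 20,        // epochs to wait with no improvement before stopping
    metricsCallback: metrics => Console.WriteLine(metrics));
```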