-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Early stopping feature in ImageClassification (WIP) #4237
Conversation
commit 1e990b16209f9d293dfb3111d152b0b8ff9c0fb6 Author: Aishwarya Bhandare <aibhanda@microsoft.com> Date: Fri Sep 20 15:27:11 2019 -0700 cleanup .gitignore commit 54ccaa6d79e420f4624bf1779053ed8709cb3dc9 Author: Aishwarya Bhandare <aibhanda@microsoft.com> Date: Fri Sep 20 15:25:51 2019 -0700 cleanup commit 93b966453895acc40468c7f4d339540e0c7729fb Author: Aishwarya Bhandare <aibhanda@microsoft.com> Date: Fri Sep 20 14:39:07 2019 -0700 initial support for eary stopping feature in ImageClassification.
fcd9a99
to
7432ba7
Compare
src/Microsoft.ML.Dnn/DnnCatalog.cs
Outdated
@@ -113,6 +116,9 @@ public static class DnnCatalog | |||
int epoch = 100, | |||
int batchSize = 10, | |||
float learningRate = 0.01f, | |||
bool enableEarlyStopping = true, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need a enableEarlyStopping
parameter? What if we instead used nullable earlyStoppingminDelta
and earlyStoppingPatience
parameters, whose default value is null
. If the user doesn't supply those values, then early stopping isn't enabled. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to @Zeeshan Siddiqui , we do not want the user to have to set any of these parameters, and want the default values to work well in most of the cases. If the default is set to null, the users would have to appropirately set these values to make use of this feature.
In reply to: 327289504 [](ancestors = 327289504)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not convinced this is the best API to enable this. What happens if we want to enable a different stopping criteria in the future?
It feels like we should consider a different API to enable this. Check out the EarlyStoppingCriteria in FastTree. That seems like more of an extensible/future-proof API.
machinelearning/src/Microsoft.ML.FastTree/Training/EarlyStoppingCriteria.cs
Lines 34 to 38 in 2942ca4
/// <summary> | |
/// Early stopping rule used to terminate training process once meeting a specified criterion. | |
/// Used for setting <see cref="EarlyStoppingRule"/> <see cref="BoostedTreeOptions.EarlyStoppingRule"/>. | |
/// </summary> | |
public abstract class EarlyStoppingRuleBase |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eerhardt @ashbhandare We spoke about this on Friday and I suggested we have a class object that defines early stopping criteria, this class should extend an interface that defines bool ShouldStop(...). The API parameter should be a reference to this interface EarlyStoping and if it is set to null then we don't apply early stopping but by default it can be set to XYXEarlyStoping ... #Resolved
wait += 1; | ||
if (wait >= options.Patience) | ||
{ | ||
Console.WriteLine("*** Early Stopping at epoch " + epoch.ToString()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please don't Console.Write
inside of library code. #Resolved
int wait = 0; | ||
var history = new TrainMetrics(); | ||
history.Accuracy = 0; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why are you using a TrainMetrics object when you only care about the history of the accuracy. you can simply use a float here instead. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was with a potential of possibly using other train metrics as well for the criteria for early stopping. I will refactor the code and this will change. #Resolved
@@ -67,6 +67,7 @@ public static void Example() | |||
epoch: 50, | |||
batchSize: 10, | |||
learningRate: 0.01f, | |||
enableEarlyStopping: true, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will prefer you create a new sample for early stopping. #Resolved
src/Microsoft.ML.Dnn/DnnCatalog.cs
Outdated
@@ -89,6 +89,9 @@ public static class DnnCatalog | |||
/// <param name="epoch">Number of training iterations. Each iteration/epoch refers to one pass over the dataset.</param> | |||
/// <param name="batchSize">The batch size for training.</param> | |||
/// <param name="learningRate">The learning rate for training.</param> | |||
/// <param name="enableEarlyStopping">Whether early stopping technique should be used when accuracy stops improving.</param> | |||
/// <param name="earlyStoppingminDelta">Minimum change in accuracy to qualify as improvement.</param> | |||
/// <param name="earlyStoppingPatience">Number of epochs to wait after no improvement is observed before early stopping.</param> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make this option a class. #Resolved
earlyStop = true; | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add comments here that document this technique and also add relevant links #Resolved
wait += 1; | ||
if (wait >= options.Patience) | ||
{ | ||
Console.WriteLine("*** Early Stopping at epoch " + epoch.ToString()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Console.WriteLine("*** Early Stopping at epoch " + epoch.ToString()); [](start = 28, length = 69)
Use message channels for logging. #Resolved
bool earlyStop = false; | ||
int wait = 0; | ||
var history = new TrainMetrics(); | ||
history.Accuracy = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why create the a new class when you just want a variable to store accuracy? just create a float variable "lastSeenAccuracy" #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Console.WriteLine("*** Early Stopping at epoch " + epoch.ToString()); | ||
earlyStop = true; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please do not put a PR without a unit-test, if you do, please mark it as Draft PR or WIP. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🕐
src/Microsoft.ML.Dnn/DnnCatalog.cs
Outdated
@@ -136,6 +142,9 @@ public static class DnnCatalog | |||
Epoch = epoch, | |||
LearningRate = learningRate, | |||
BatchSize = batchSize, | |||
EnableEarlyStopping = enableEarlyStopping, | |||
MinDelta = earlyStoppingminDelta, | |||
Patience = earlyStoppingPatience, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since you are taking this technique from https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/python/keras/callbacks.py#L1143-L1260
were you also planning to add the "modes", i.e min, max, auto?
mode: One of {"auto", "min", "max"}
. In min
mode,
training will stop when the quantity
monitored has stopped decreasing; in max
mode it will stop when the quantity
monitored has stopped increasing; in auto
mode, the direction is automatically inferred
from the name of the monitored quantity.
I think we should add this. #Resolved
{ | ||
history.Accuracy = metrics.Train.Accuracy; | ||
wait = 0; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is incorrect. It needs to be absolute change as documented here: https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/python/keras/callbacks.py#L1143-L1260
min_delta: Minimum change in the monitored quantity
to qualify as an improvement, i.e. an absolute
change of less than min_delta, will count as no
improvement.
Looking at the code:
if mode == 'min':
self.monitor_op = np.less
elif mode == 'max':
self.monitor_op = np.greater
else:
if 'acc' in self.monitor:
self.monitor_op = np.greater
else:
self.monitor_op = np.less
if self.monitor_op == np.greater:
self.min_delta *= 1
else:
self.min_delta *= -1
The last 4 lines above change the sign of the delta and that takes care of absolute difference in below function:
def on_epoch_end(self, epoch, logs=None):
current = self.get_monitor_value(logs)
if current is None:
return
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
if self.restore_best_weights:
self.best_weights = self.model.get_weights()
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
if self.restore_best_weights:
if self.verbose > 0:
print('Restoring model weights from the end of the best epoch.')
self.model.set_weights(self.best_weights)
#Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As we discussed offline, the change in sign of the min_delta doesn't handle taking the absolute value. Even though it is mentioned in the comment that they take absolute value of the change, it is not implemented that way in the code. However, we want to consider absolute value and I will make that change.
In reply to: 328248944 [](ancestors = 328248944)
Codecov Report
@@ Coverage Diff @@
## master #4237 +/- ##
=========================================
Coverage ? 74.56%
=========================================
Files ? 878
Lines ? 154012
Branches ? 16852
=========================================
Hits ? 114833
Misses ? 34446
Partials ? 4733
|
if (options.EarlyStopper != null) | ||
{ | ||
earlyStop = options.EarlyStopper.ShouldStop(metrics.Train); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not just break out? it will save you a variable #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
||
/// <summary> | ||
/// Current number of epochs where there has been no improvement. | ||
/// Stop training when wait >=patience. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wait [](start = 35, length = 4)
please use param ref to refer variables. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated description to not use variables. Couldn't use as these variables are not parameters to this particular variable(i guess).
In reply to: 330316879 [](ancestors = 330316879)
currentMetricValue = currentMetrics.Accuracy; | ||
else | ||
currentMetricValue = currentMetrics.CrossEntropy; | ||
if(CheckIncreasing) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if(CheckIncreasing) [](start = 16, length = 19)
new line #Resolved
if (_metric == EarlyStoppingMetric.Accuracy) | ||
currentMetricValue = currentMetrics.Accuracy; | ||
else | ||
currentMetricValue = currentMetrics.CrossEntropy; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
currentMetricValue = _metric == EarlyStoppingMetric.Accuracy ? currentMetrics.Accuracy : currentMetrics.CrossEntropy #Resolved
{ | ||
_wait += 1; | ||
if(_wait >= Patience) | ||
return (true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(true); [](start = 35, length = 7)
why the brackets? just "return true;" #Resolved
if((currentMetricValue- _bestMetricValue) < MinDelta) | ||
{ | ||
_wait += 1; | ||
if(_wait >= Patience) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if(_wait >= Patience) [](start = 24, length = 21)
Can _wait ever be greater than Patience? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since patience is an int, user might supply a negative value. in that case, it is better to check >= instead of ==.
In reply to: 330317534 [](ancestors = 330317534)
{ | ||
_wait += 1; | ||
if (_wait >= Patience) | ||
return (true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return (true); [](start = 28, length = 14)
return true #Resolved
_bestMetricValue = currentMetricValue; | ||
} | ||
} | ||
return (false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return (false); [](start = 15, length = 16)
return false #Resolved
/// Early Stopping technique to stop training when accuracy stops improving. | ||
/// </summary> | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Early Stopping technique to stop training when accuracy stops improving.", SortOrder = 15)] | ||
public EarlyStopping EarlyStopper; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
EarlyStopper [](start = 33, length = 12)
We generally refer this as EarlyStoppingCriteria #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
epoch: 50, | ||
batchSize: 5, | ||
learningRate: 0.01f, | ||
earlyStopping: new ImageClassificationEstimator.EarlyStopping(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
earlyStopping: new ImageClassificationEstimator.EarlyStopping(), [](start = 16, length = 64)
isn't this the case by default? How is this test different from the above test? May be in the above test disable early stopping so we get that case covered and here enable but also verify the epoch at which it stops via metrics callback #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
Show resolved
Hide resolved
test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixes #4236
Modeled after https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping