Adding Early stopping feature in ImageClassification (WIP) #4237

Merged
merged 10 commits into from
Oct 2, 2019

ashbhandare
Contributor

@ashbhandare ashbhandare requested a review from a team as a code owner September 20, 2019 23:06
commit 1e990b16209f9d293dfb3111d152b0b8ff9c0fb6
Author: Aishwarya Bhandare <aibhanda@microsoft.com>
Date:   Fri Sep 20 15:27:11 2019 -0700

    cleanup .gitignore

commit 54ccaa6d79e420f4624bf1779053ed8709cb3dc9
Author: Aishwarya Bhandare <aibhanda@microsoft.com>
Date:   Fri Sep 20 15:25:51 2019 -0700

    cleanup

commit 93b966453895acc40468c7f4d339540e0c7729fb
Author: Aishwarya Bhandare <aibhanda@microsoft.com>
Date:   Fri Sep 20 14:39:07 2019 -0700

    initial support for eary stopping feature in ImageClassification.

src/Microsoft.ML.Dnn/ImageClassificationTransform.cs (outdated)
src/Microsoft.ML.Dnn/DnnCatalog.cs (outdated)
@@ -113,6 +116,9 @@ public static class DnnCatalog
int epoch = 100,
int batchSize = 10,
float learningRate = 0.01f,
bool enableEarlyStopping = true,
Member

@eerhardt eerhardt Sep 23, 2019

Do we really need an enableEarlyStopping parameter? What if we instead used nullable earlyStoppingminDelta and earlyStoppingPatience parameters whose default value is null? If the user doesn't supply those values, then early stopping isn't enabled. #Resolved

Contributor Author

According to @Zeeshan Siddiqui, we do not want the user to have to set any of these parameters, and we want the default values to work well in most cases. If the default were null, users would have to set these values appropriately to make use of this feature.


In reply to: 327289504 [](ancestors = 327289504)

Member

@eerhardt eerhardt Sep 27, 2019

I'm not convinced this is the best API to enable this. What happens if we want to enable a different stopping criterion in the future?

It feels like we should consider a different API to enable this. Check out the EarlyStoppingCriteria in FastTree. That seems like more of an extensible/future-proof API.

/// <summary>
/// Early stopping rule used to terminate training process once meeting a specified criterion.
/// Used for setting <see cref="EarlyStoppingRule"/> <see cref="BoostedTreeOptions.EarlyStoppingRule"/>.
/// </summary>
public abstract class EarlyStoppingRuleBase
#Resolved

Member

@codemzs codemzs Sep 29, 2019

@eerhardt @ashbhandare We spoke about this on Friday, and I suggested a class that defines the early stopping criteria; this class should implement an interface that defines bool ShouldStop(...). The API parameter should be a reference to this interface, EarlyStopping. If it is set to null then we don't apply early stopping, but by default it can be set to XYZEarlyStopping ... #Resolved
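
The shape of this proposal can be sketched roughly as follows. This is a minimal Python illustration rather than the C# that was merged; `EarlyStoppingCriterion`, `AccuracyPlateau`, `train`, and the default values for `min_delta` and `patience` are all hypothetical names and numbers chosen for the example.

```python
from abc import ABC, abstractmethod
from typing import Iterable, Optional


class EarlyStoppingCriterion(ABC):
    """Hypothetical interface: a single method decides whether to stop."""

    @abstractmethod
    def should_stop(self, metric: float) -> bool:
        ...


class AccuracyPlateau(EarlyStoppingCriterion):
    """Hypothetical default: stop after `patience` epochs without a gain of `min_delta`."""

    def __init__(self, min_delta: float = 0.01, patience: int = 20) -> None:
        # Illustrative defaults, not taken from the PR.
        self.min_delta = min_delta
        self.patience = patience
        self._best = 0.0
        self._wait = 0

    def should_stop(self, metric: float) -> bool:
        if metric - self._best >= self.min_delta:
            # Improvement: remember the new best value and reset the counter.
            self._best = metric
            self._wait = 0
            return False
        self._wait += 1
        return self._wait >= self.patience


def train(accuracies: Iterable[float],
          criterion: Optional[EarlyStoppingCriterion]) -> int:
    """Stand-in training loop; returns the number of epochs actually run."""
    epochs_run = 0
    for accuracy in accuracies:
        epochs_run += 1
        # Passing None disables early stopping entirely.
        if criterion is not None and criterion.should_stop(accuracy):
            break
    return epochs_run
```

With this shape, a new stopping rule is just another subclass, and the training loop does not change when one is added.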

wait += 1;
if (wait >= options.Patience)
{
    Console.WriteLine("*** Early Stopping at epoch " + epoch.ToString());
Member

@eerhardt eerhardt Sep 23, 2019

Please don't Console.Write inside of library code. #Resolved

src/Microsoft.ML.Dnn/ImageClassificationTransform.cs (outdated)
int wait = 0;
var history = new TrainMetrics();
history.Accuracy = 0;

Contributor

@bpstark bpstark Sep 23, 2019

Why are you using a TrainMetrics object when you only care about the history of the accuracy? You can simply use a float here instead. #Resolved

Contributor Author

@ashbhandare ashbhandare Sep 26, 2019

This was to leave room for using other train metrics as early stopping criteria as well. I will refactor the code and this will change. #Resolved

@@ -67,6 +67,7 @@ public static void Example()
epoch: 50,
batchSize: 10,
learningRate: 0.01f,
enableEarlyStopping: true,
Member

@codemzs codemzs Sep 25, 2019

I would prefer that you create a new sample for early stopping. #Resolved

@@ -89,6 +89,9 @@ public static class DnnCatalog
/// <param name="epoch">Number of training iterations. Each iteration/epoch refers to one pass over the dataset.</param>
/// <param name="batchSize">The batch size for training.</param>
/// <param name="learningRate">The learning rate for training.</param>
/// <param name="enableEarlyStopping">Whether early stopping technique should be used when accuracy stops improving.</param>
/// <param name="earlyStoppingminDelta">Minimum change in accuracy to qualify as improvement.</param>
/// <param name="earlyStoppingPatience">Number of epochs to wait after no improvement is observed before early stopping.</param>
Member

@codemzs codemzs Sep 25, 2019

Make this option a class. #Resolved

earlyStop = true;
}
}
}
Member

@codemzs codemzs Sep 25, 2019

Please add comments here that document this technique and also add relevant links #Resolved

wait += 1;
if (wait >= options.Patience)
{
    Console.WriteLine("*** Early Stopping at epoch " + epoch.ToString());
Member

@codemzs codemzs Sep 25, 2019

Console.WriteLine("*** Early Stopping at epoch " + epoch.ToString()); [](start = 28, length = 69)

Use message channels for logging. #Resolved

bool earlyStop = false;
int wait = 0;
var history = new TrainMetrics();
history.Accuracy = 0;
Member

@codemzs codemzs Sep 25, 2019

Why create a new class when you just want a variable to store accuracy? Just create a float variable "lastSeenAccuracy". #Resolved

Contributor Author

Addressed here: #4237 (comment)


In reply to: 328229850 [](ancestors = 328229850)

Console.WriteLine("*** Early Stopping at epoch " + epoch.ToString());
earlyStop = true;
}
}
Member

@codemzs codemzs Sep 25, 2019

Please do not put up a PR without a unit test; if you do, please mark it as a Draft PR or WIP. #Resolved

Member

@codemzs codemzs left a comment

🕐

@@ -136,6 +142,9 @@ public static class DnnCatalog
Epoch = epoch,
LearningRate = learningRate,
BatchSize = batchSize,
EnableEarlyStopping = enableEarlyStopping,
MinDelta = earlyStoppingminDelta,
Patience = earlyStoppingPatience,
Member

@codemzs codemzs Sep 25, 2019

Since you are taking this technique from https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/python/keras/callbacks.py#L1143-L1260

were you also planning to add the "modes", i.e min, max, auto?

mode: One of {"auto", "min", "max"}. In min mode,
training will stop when the quantity
monitored has stopped decreasing; in max
mode it will stop when the quantity
monitored has stopped increasing; in auto
mode, the direction is automatically inferred
from the name of the monitored quantity.

I think we should add this. #Resolved

{
    history.Accuracy = metrics.Train.Accuracy;
    wait = 0;
}
Member

@codemzs codemzs Sep 25, 2019

This is incorrect. It needs to be absolute change as documented here: https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/python/keras/callbacks.py#L1143-L1260

min_delta: Minimum change in the monitored quantity
to qualify as an improvement, i.e. an absolute
change of less than min_delta, will count as no
improvement.

Looking at the code:

if mode == 'min':
  self.monitor_op = np.less
elif mode == 'max':
  self.monitor_op = np.greater
else:
  if 'acc' in self.monitor:
    self.monitor_op = np.greater
  else:
    self.monitor_op = np.less

if self.monitor_op == np.greater:
  self.min_delta *= 1
else:
  self.min_delta *= -1

The last 4 lines above change the sign of the delta and that takes care of absolute difference in below function:

def on_epoch_end(self, epoch, logs=None):
  current = self.get_monitor_value(logs)
  if current is None:
    return
  if self.monitor_op(current - self.min_delta, self.best):
    self.best = current
    self.wait = 0
    if self.restore_best_weights:
      self.best_weights = self.model.get_weights()
  else:
    self.wait += 1
    if self.wait >= self.patience:
      self.stopped_epoch = epoch
      self.model.stop_training = True
      if self.restore_best_weights:
        if self.verbose > 0:
          print('Restoring model weights from the end of the best epoch.')
        self.model.set_weights(self.best_weights)

#Resolved

Contributor Author

As we discussed offline, the change in sign of the min_delta doesn't handle taking the absolute value. Even though it is mentioned in the comment that they take absolute value of the change, it is not implemented that way in the code. However, we want to consider absolute value and I will make that change.


In reply to: 328248944 [](ancestors = 328248944)
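
The disagreement in this thread can be checked with plain arithmetic. The sketch below mirrors the quoted comparison `monitor_op(current - min_delta, best)` without NumPy and sets it beside one literal reading of "absolute change". Both function names are hypothetical, and the second function is only an assumed interpretation for contrast, not what the PR ended up implementing.

```python
def keras_style_improved(current: float, best: float, min_delta: float,
                         mode: str = "max") -> bool:
    """Signed check, as in the quoted callback.

    In max mode the callback evaluates np.greater(current - min_delta, best).
    In min mode min_delta is negated first, so np.less(current + min_delta, best).
    """
    if mode == "max":
        return current - min_delta > best
    return current + min_delta < best


def absolute_change_improved(current: float, best: float, min_delta: float) -> bool:
    """Assumed literal reading: any change of at least min_delta, up or down, counts."""
    return abs(current - best) >= min_delta
```

For `best = 0.80`, `current = 0.70`, and `min_delta = 0.05`, the signed check reports no improvement while the absolute check reports a significant change, so flipping the sign of `min_delta` is not the same as taking an absolute value.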

@ashbhandare ashbhandare changed the title Adding Early stopping feature in ImageClassification Adding Early stopping feature in ImageClassification (WIP) Sep 26, 2019
@codecov

codecov bot commented Sep 30, 2019

Codecov Report

❗ No coverage uploaded for pull request base (master@d290881).
The diff coverage is 90.44%.

@@            Coverage Diff            @@
##             master    #4237   +/-   ##
=========================================
  Coverage          ?   74.56%           
=========================================
  Files             ?      878           
  Lines             ?   154012           
  Branches          ?    16852           
=========================================
  Hits              ?   114833           
  Misses            ?    34446           
  Partials          ?     4733

Flag          Coverage Δ
#Debug        74.56% <90.44%> (?)
#production   70.15% <90.19%> (?)
#test         89.51% <90.56%> (?)

Impacted Files                                           Coverage Δ
src/Microsoft.ML.Dnn/DnnCatalog.cs                       78.66% <100%> (ø)
...c/Microsoft.ML.Dnn/ImageClassificationTransform.cs    86.25% <90%> (ø)
...cenariosWithDirectInstantiation/TensorflowTests.cs    89.96% <90.56%> (ø)

if (options.EarlyStopper != null)
{
    earlyStop = options.EarlyStopper.ShouldStop(metrics.Train);
}
Member

@codemzs codemzs Oct 1, 2019

Why not just break out? It will save you a variable. #Resolved

Contributor Author

done


In reply to: 330316698 [](ancestors = 330316698)


/// <summary>
/// Current number of epochs where there has been no improvement.
/// Stop training when wait >=patience.
Member

@codemzs codemzs Oct 1, 2019

wait [](start = 35, length = 4)

Please use paramref to refer to variables. #Resolved

Contributor Author

Updated the description to not use variables. I couldn't use paramref, as these variables are not parameters of this particular field (I guess).


In reply to: 330316879 [](ancestors = 330316879)

    currentMetricValue = currentMetrics.Accuracy;
else
    currentMetricValue = currentMetrics.CrossEntropy;
if(CheckIncreasing)
Member

@codemzs codemzs Oct 1, 2019

if(CheckIncreasing) [](start = 16, length = 19)

new line #Resolved

if (_metric == EarlyStoppingMetric.Accuracy)
    currentMetricValue = currentMetrics.Accuracy;
else
    currentMetricValue = currentMetrics.CrossEntropy;
Member

@codemzs codemzs Oct 1, 2019

currentMetricValue = _metric == EarlyStoppingMetric.Accuracy ? currentMetrics.Accuracy : currentMetrics.CrossEntropy #Resolved

{
    _wait += 1;
    if(_wait >= Patience)
        return (true);
Member

@codemzs codemzs Oct 1, 2019

(true); [](start = 35, length = 7)

why the brackets? just "return true;" #Resolved

if((currentMetricValue- _bestMetricValue) < MinDelta)
{
    _wait += 1;
    if(_wait >= Patience)
Member

@codemzs codemzs Oct 1, 2019

if(_wait >= Patience) [](start = 24, length = 21)

Can _wait ever be greater than Patience? #Resolved

Contributor Author

Since Patience is an int, the user might supply a negative value. In that case, it is better to check >= instead of ==.


In reply to: 330317534 [](ancestors = 330317534)
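
The point about `>=` versus `==` can be seen in a small simulation. This is a hedged Python sketch, not the PR's C#; `epochs_until_stop` is a hypothetical helper that assumes no epoch ever improves, so the wait counter grows by one per epoch.

```python
from typing import Optional


def epochs_until_stop(patience: int, use_equality: bool,
                      max_epochs: int = 10) -> Optional[int]:
    """Return the epoch at which training stops, or None if it never does."""
    wait = 0
    for epoch in range(1, max_epochs + 1):
        wait += 1  # assumption: this epoch brought no improvement
        if use_equality:
            triggered = wait == patience
        else:
            triggered = wait >= patience
        if triggered:
            return epoch
    return None
```

With a positive patience both comparisons stop at the same epoch. With a negative patience the counter starts above the threshold, so `==` can never fire and training runs to `max_epochs`, while `>=` stops after the first epoch without improvement.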

{
    _wait += 1;
    if (_wait >= Patience)
        return (true);
Member

@codemzs codemzs Oct 1, 2019

return (true); [](start = 28, length = 14)

return true #Resolved

_bestMetricValue = currentMetricValue;
}
}
return (false);
Member

@codemzs codemzs Oct 1, 2019

return (false); [](start = 15, length = 16)

return false #Resolved

/// Early Stopping technique to stop training when accuracy stops improving.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Early Stopping technique to stop training when accuracy stops improving.", SortOrder = 15)]
public EarlyStopping EarlyStopper;
Member

@codemzs codemzs Oct 1, 2019

EarlyStopper [](start = 33, length = 12)

We generally refer to this as EarlyStoppingCriteria. #Resolved

Contributor Author

renamed


In reply to: 330318020 [](ancestors = 330318020)

epoch: 50,
batchSize: 5,
learningRate: 0.01f,
earlyStopping: new ImageClassificationEstimator.EarlyStopping(),
Member

@codemzs codemzs Oct 1, 2019

earlyStopping: new ImageClassificationEstimator.EarlyStopping(), [](start = 16, length = 64)

Isn't this the case by default? How is this test different from the above test? Maybe disable early stopping in the above test so that case is covered, and enable it here, but also verify the epoch at which it stops via the metrics callback. #Resolved

Contributor Author

done


In reply to: 330318498 [](ancestors = 330318498)
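
One way to picture the requested test is to record epochs through a metrics callback and assert on the last one seen. The following is a rough Python sketch of that idea, not the test added to TensorflowTests.cs; `run_training`, `last_reported_epoch`, and the accuracy sequence are hypothetical stand-ins for a real training run.

```python
from typing import Callable, List, Sequence


def run_training(accuracies: Sequence[float], patience: int, min_delta: float,
                 metrics_callback: Callable[[int, float], None]) -> None:
    """Stand-in training loop that reports each epoch before checking the stop rule."""
    best = 0.0
    wait = 0
    for epoch, accuracy in enumerate(accuracies, start=1):
        metrics_callback(epoch, accuracy)
        if accuracy - best >= min_delta:
            best, wait = accuracy, 0
        else:
            wait += 1
            if wait >= patience:
                break


def last_reported_epoch(accuracies: Sequence[float], patience: int,
                        min_delta: float) -> int:
    """Collect epochs through the callback and return the final one reported."""
    seen: List[int] = []
    run_training(accuracies, patience, min_delta,
                 lambda epoch, _accuracy: seen.append(epoch))
    return seen[-1]
```

A test built this way can compare the last reported epoch against the configured epoch count: equal when early stopping is effectively off, smaller when it triggers.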

Member

@codemzs codemzs left a comment

:shipit:

@ashbhandare ashbhandare merged commit f8a672a into dotnet:master Oct 2, 2019
@ashbhandare ashbhandare deleted the early_stopping branch October 2, 2019 17:27
@codemzs codemzs mentioned this pull request Oct 3, 2019
@ghost ghost locked as resolved and limited conversation to collaborators Mar 20, 2022
Development

Successfully merging this pull request may close these issues.

Need Early Stopping feature in Image Classification
4 participants