Skip to content

Commit

Permalink
Defaults for ImageClassification API (#4415)
Browse files Browse the repository at this point in the history
* Changed some defaults

* Changed metrics callback default

* metricsCallback will write to mlcontext log by default. The sample has been update to show how to get the output to console from the log.

* deleted unnecessary comments

* Addressed comments

* Minor clean up.

* Disable unstable test.
  • Loading branch information
harshithapv authored and codemzs committed Nov 1, 2019
1 parent f341ca3 commit b9c68bf
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ public static void Example()
//Download the image set and unzip
string finalImagesFolderName = DownloadImageSet(
imagesDownloadFolderPath);

string fullImagesetFolderPath = Path.Combine(
imagesDownloadFolderPath, finalImagesFolderName);

try
{

MLContext mlContext = new MLContext(seed: 1);
mlContext.Log += MlContext_Log;

//Load all the original images info
IEnumerable<ImageData> images = LoadImagesFromDirectory(
Expand All @@ -60,7 +62,7 @@ public static void Example()
IDataView testDataset = trainTestData.TestSet;

var pipeline = mlContext.MulticlassClassification.Trainers
.ImageClassification(featureColumnName:"Image", validationSet:testDataset)
.ImageClassification(featureColumnName:"Image")
.Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel",
inputColumnName: "PredictedLabel"));

Expand Down Expand Up @@ -109,6 +111,14 @@ public static void Example()
Console.ReadKey();
}

private static void MlContext_Log(object sender, LoggingEventArgs e)
{
if (e.Message.StartsWith("[Source=ImageClassificationTrainer;"))
{
Console.WriteLine(e.Message);
}
}

private static void TrySinglePrediction(string imagesForPredictions,
MLContext mlContext, ITransformer trainedModel)
{
Expand Down
10 changes: 8 additions & 2 deletions src/Microsoft.ML.Vision/ImageClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ public sealed class Options : TrainerInputBaseWithLabel
/// Early stopping technique parameters to be used to terminate training when training metric stops improving.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Early stopping technique parameters to be used to terminate training when training metric stops improving.", SortOrder = 15)]
public EarlyStopping EarlyStoppingCriteria;
public EarlyStopping EarlyStoppingCriteria = new EarlyStopping();

/// <summary>
/// Specifies the model architecture to be used in the case of image classification training using transfer learning.
Expand Down Expand Up @@ -437,7 +437,7 @@ public sealed class Options : TrainerInputBaseWithLabel
/// A class that performs learning rate scheduling.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "A class that performs learning rate scheduling.", SortOrder = 15)]
public LearningRateScheduler LearningRateScheduler = new LsrDecay();
public LearningRateScheduler LearningRateScheduler = new ExponentialLRDecay();
}

/// <summary> Return the type of prediction task.</summary>
Expand Down Expand Up @@ -532,6 +532,12 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
options.ValidationSetBottleneckCachedValuesFileName = _options.ValidationSetBottleneckCachedValuesFileName;
}

if (options.MetricsCallback == null)
{
var logger = Host.Start(nameof(ImageClassificationTrainer));
options.MetricsCallback = (ImageClassificationMetrics metric) => { logger.Trace(metric.ToString()); };
}

_options = options;
_useLRScheduling = _options.LearningRateScheduler != null;
_checkpointPath = Path.Combine(_options.WorkspacePath, _options.FinalModelPrefix +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1431,7 +1431,7 @@ public void TensorFlowImageClassificationWithExponentialLRScheduling()
TensorFlowImageClassificationWithLRScheduling(new ExponentialLRDecay());
}

[TensorFlowFact]
[Fact(Skip ="Very unstable tests, causing many build failures.")]
public void TensorFlowImageClassificationWithPolynomialLRScheduling()
{
TensorFlowImageClassificationWithLRScheduling(new PolynomialLRDecay());
Expand Down

0 comments on commit b9c68bf

Please sign in to comment.