From b9c68bff79c7afa71c85499cdaf9c46426abbf52 Mon Sep 17 00:00:00 2001 From: harshithapv <54084812+harshithapv@users.noreply.github.com> Date: Fri, 1 Nov 2019 22:01:04 +0000 Subject: [PATCH] Defaults for ImageClassification API (#4415) * Changed some defaults * Changed metrics callback default * metricsCallback will write to mlcontext log by default. The sample has been update to show how to get the output to console from the log. * deleted unnecessary comments * Addressed comments * Minor clean up. * Disable unstable test. --- .../ImageClassificationDefault.cs | 12 +++++++++++- .../ImageClassificationTrainer.cs | 10 ++++++++-- .../TensorflowTests.cs | 2 +- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ImageClassificationDefault.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ImageClassificationDefault.cs index ad999599c9..f128ba7cdc 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ImageClassificationDefault.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ImageClassificationDefault.cs @@ -29,6 +29,7 @@ public static void Example() //Download the image set and unzip string finalImagesFolderName = DownloadImageSet( imagesDownloadFolderPath); + string fullImagesetFolderPath = Path.Combine( imagesDownloadFolderPath, finalImagesFolderName); @@ -36,6 +37,7 @@ public static void Example() { MLContext mlContext = new MLContext(seed: 1); + mlContext.Log += MlContext_Log; //Load all the original images info IEnumerable images = LoadImagesFromDirectory( @@ -60,7 +62,7 @@ public static void Example() IDataView testDataset = trainTestData.TestSet; var pipeline = mlContext.MulticlassClassification.Trainers - .ImageClassification(featureColumnName:"Image", validationSet:testDataset) + .ImageClassification(featureColumnName:"Image") .Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel", inputColumnName: "PredictedLabel")); @@ -109,6 +111,14 @@ public static void Example() Console.ReadKey(); } + private static void MlContext_Log(object sender, LoggingEventArgs e) + { + if (e.Message.StartsWith("[Source=ImageClassificationTrainer;")) + { + Console.WriteLine(e.Message); + } + } + private static void TrySinglePrediction(string imagesForPredictions, MLContext mlContext, ITransformer trainedModel) { diff --git a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs index 5958e9920d..8d421d77ba 100644 --- a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs +++ b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs @@ -359,7 +359,7 @@ public sealed class Options : TrainerInputBaseWithLabel /// Early stopping technique parameters to be used to terminate training when training metric stops improving. /// [Argument(ArgumentType.AtMostOnce, HelpText = "Early stopping technique parameters to be used to terminate training when training metric stops improving.", SortOrder = 15)] - public EarlyStopping EarlyStoppingCriteria; + public EarlyStopping EarlyStoppingCriteria = new EarlyStopping(); /// /// Specifies the model architecture to be used in the case of image classification training using transfer learning. @@ -437,7 +437,7 @@ public sealed class Options : TrainerInputBaseWithLabel /// A class that performs learning rate scheduling. /// [Argument(ArgumentType.AtMostOnce, HelpText = "A class that performs learning rate scheduling.", SortOrder = 15)] - public LearningRateScheduler LearningRateScheduler = new LsrDecay(); + public LearningRateScheduler LearningRateScheduler = new ExponentialLRDecay(); } /// Return the type of prediction task. @@ -532,6 +532,12 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options) options.ValidationSetBottleneckCachedValuesFileName = _options.ValidationSetBottleneckCachedValuesFileName; } + if (options.MetricsCallback == null) + { + var logger = Host.Start(nameof(ImageClassificationTrainer)); + options.MetricsCallback = (ImageClassificationMetrics metric) => { logger.Trace(metric.ToString()); }; + } + _options = options; _useLRScheduling = _options.LearningRateScheduler != null; _checkpointPath = Path.Combine(_options.WorkspacePath, _options.FinalModelPrefix + diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 3c3b078130..1b441f5326 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -1431,7 +1431,7 @@ public void TensorFlowImageClassificationWithExponentialLRScheduling() TensorFlowImageClassificationWithLRScheduling(new ExponentialLRDecay()); } - [TensorFlowFact] + [Fact(Skip ="Very unstable tests, causing many build failures.")] public void TensorFlowImageClassificationWithPolynomialLRScheduling() { TensorFlowImageClassificationWithLRScheduling(new PolynomialLRDecay());