Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Defaults for ImageClassification API #4415

Merged
merged 7 commits into from
Nov 1, 2019
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ public static void Example()
//Download the image set and unzip
string finalImagesFolderName = DownloadImageSet(
imagesDownloadFolderPath);

string fullImagesetFolderPath = Path.Combine(
imagesDownloadFolderPath, finalImagesFolderName);

try
{

MLContext mlContext = new MLContext(seed: 1);
mlContext.Log += MlContext_Log;

//Load all the original images info
IEnumerable<ImageData> images = LoadImagesFromDirectory(
Expand All @@ -60,7 +62,7 @@ public static void Example()
IDataView testDataset = trainTestData.TestSet;

var pipeline = mlContext.MulticlassClassification.Trainers
.ImageClassification(featureColumnName:"Image", validationSet:testDataset)
.ImageClassification(featureColumnName:"Image")
.Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel",
inputColumnName: "PredictedLabel"));

Expand Down Expand Up @@ -109,6 +111,14 @@ public static void Example()
Console.ReadKey();
}

private static void MlContext_Log(object sender, LoggingEventArgs e)
{
if (e.Message.StartsWith("[Source=ImageClassificationTrainer;"))
{
Console.WriteLine(e.Message);
}
}

private static void TrySinglePrediction(string imagesForPredictions,
MLContext mlContext, ITransformer trainedModel)
{
Expand Down
10 changes: 8 additions & 2 deletions src/Microsoft.ML.Vision/ImageClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ public sealed class Options : TrainerInputBaseWithLabel
/// Early stopping technique parameters to be used to terminate training when training metric stops improving.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Early stopping technique parameters to be used to terminate training when training metric stops improving.", SortOrder = 15)]
public EarlyStopping EarlyStoppingCriteria;
public EarlyStopping EarlyStoppingCriteria = new EarlyStopping();

/// <summary>
/// Specifies the model architecture to be used in the case of image classification training using transfer learning.
Expand Down Expand Up @@ -437,7 +437,7 @@ public sealed class Options : TrainerInputBaseWithLabel
/// A class that performs learning rate scheduling.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "A class that performs learning rate scheduling.", SortOrder = 15)]
public LearningRateScheduler LearningRateScheduler = new LsrDecay();
public LearningRateScheduler LearningRateScheduler = new ExponentialLRDecay();
}

/// <summary> Return the type of prediction task.</summary>
Expand Down Expand Up @@ -532,6 +532,12 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
options.ValidationSetBottleneckCachedValuesFileName = _options.ValidationSetBottleneckCachedValuesFileName;
}

if (options.MetricsCallback == null)
{
var logger = Host.Start(nameof(ImageClassificationTrainer));
options.MetricsCallback = (ImageClassificationMetrics metric) => { logger.Trace(metric.ToString()); };
}

_options = options;
_useLRScheduling = _options.LearningRateScheduler != null;
_checkpointPath = Path.Combine(_options.WorkspacePath, _options.FinalModelPrefix +
Expand Down