-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Defaults for ImageClassification API #4415
Changes from 5 commits
1947cfb
7bff31b
47ac5cd
80fbb8f
9c2a986
82ef427
abf35e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,13 +29,15 @@ public static void Example() | |
//Download the image set and unzip | ||
string finalImagesFolderName = DownloadImageSet( | ||
imagesDownloadFolderPath); | ||
//string finalImagesFolderName = "flower_photos"; | ||
string fullImagesetFolderPath = Path.Combine( | ||
imagesDownloadFolderPath, finalImagesFolderName); | ||
|
||
try | ||
{ | ||
|
||
MLContext mlContext = new MLContext(seed: 1); | ||
mlContext.Log += MlContext_Log; | ||
|
||
//Load all the original images info | ||
IEnumerable<ImageData> images = LoadImagesFromDirectory( | ||
|
@@ -60,7 +62,7 @@ public static void Example() | |
IDataView testDataset = trainTestData.TestSet; | ||
|
||
var pipeline = mlContext.MulticlassClassification.Trainers | ||
.ImageClassification(featureColumnName:"Image", validationSet:testDataset) | ||
.ImageClassification(featureColumnName:"Image", validationSet:null) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
no need, default is null. #Resolved |
||
.Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel", | ||
inputColumnName: "PredictedLabel")); | ||
|
||
|
@@ -109,6 +111,14 @@ public static void Example() | |
Console.ReadKey(); | ||
} | ||
|
||
private static void MlContext_Log(object sender, LoggingEventArgs e) | ||
{ | ||
if (e.Message.StartsWith("[Source=ImageClassificationTrainer;")) | ||
{ | ||
Console.WriteLine(e.Message); | ||
} | ||
} | ||
|
||
private static void TrySinglePrediction(string imagesForPredictions, | ||
MLContext mlContext, ITransformer trainedModel) | ||
{ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -359,7 +359,7 @@ public sealed class Options : TrainerInputBaseWithLabel | |
/// Early stopping technique parameters to be used to terminate training when training metric stops improving. | ||
/// </summary> | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Early stopping technique parameters to be used to terminate training when training metric stops improving.", SortOrder = 15)] | ||
public EarlyStopping EarlyStoppingCriteria; | ||
public EarlyStopping EarlyStoppingCriteria = new EarlyStopping(); | ||
|
||
/// <summary> | ||
/// Specifies the model architecture to be used in the case of image classification training using transfer learning. | ||
|
@@ -437,7 +437,7 @@ public sealed class Options : TrainerInputBaseWithLabel | |
/// A class that performs learning rate scheduling. | ||
/// </summary> | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "A class that performs learning rate scheduling.", SortOrder = 15)] | ||
public LearningRateScheduler LearningRateScheduler = new LsrDecay(); | ||
public LearningRateScheduler LearningRateScheduler = new ExponentialLRDecay(); | ||
} | ||
|
||
/// <summary> Return the type of prediction task.</summary> | ||
|
@@ -532,6 +532,12 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options) | |
options.ValidationSetBottleneckCachedValuesFileName = _options.ValidationSetBottleneckCachedValuesFileName; | ||
} | ||
|
||
if ( options.MetricsCallback == null ) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
space #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
space #Closed |
||
{ | ||
var logger = Host.Start(nameof(ImageClassificationTrainer)); | ||
options.MetricsCallback = (ImageClassificationMetrics metric) => { logger.Trace(metric.ToString()); }; | ||
} | ||
|
||
_options = options; | ||
_useLRScheduling = _options.LearningRateScheduler != null; | ||
_checkpointPath = Path.Combine(_options.WorkspacePath, _options.FinalModelPrefix + | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove #Closed