Skip to content

Commit 3f94d39

Browse files
committed
fix ealy stopping sample
1 parent 0e1e778 commit 3f94d39

File tree

1 file changed

+19
-58
lines changed

1 file changed

+19
-58
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ResnetV2101TransferLearningEarlyStopping.cs

Lines changed: 19 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ public static void Example()
8989

9090
// Create the ImageClassification pipeline.
9191
var pipeline = mlContext.Transforms.LoadRawImageBytes(
92-
"Image", fullImagesetFolderPath, "ImagePath")
92+
"Image", fullImagesetFolderPath, "ImagePath")
9393
.Append(mlContext.MulticlassClassification.Trainers.
9494
ImageClassification(options));
9595

@@ -135,13 +135,17 @@ public static void Example()
135135
// Micro-accuracy: 0.851851851851852,macro-accuracy = 0.85
136136
EvaluateModel(mlContext, testDataset, loadedModel);
137137

138+
VBuffer<ReadOnlyMemory<char>> keys = default;
139+
loadedModel.GetOutputSchema(schema)["Label"].GetKeyValues(ref keys);
140+
138141
watch = System.Diagnostics.Stopwatch.StartNew();
139142

140143
// Predict on a single image class using an in-memory image.
141144
// Sample output:
142145
// Scores : [0.09683081,0.0002645972,0.007213613,0.8912219,0.004469037],
143146
// Predicted Label : daisy
144-
TrySinglePrediction(fullImagesetFolderPath, mlContext, loadedModel);
147+
TrySinglePrediction(fullImagesetFolderPath, mlContext, loadedModel,
148+
keys.DenseValues().ToArray());
145149

146150
watch.Stop();
147151
elapsedMs = watch.ElapsedMilliseconds;
@@ -160,28 +164,31 @@ public static void Example()
160164

161165
// Predict on a single image.
162166
private static void TrySinglePrediction(string imagesForPredictions,
163-
MLContext mlContext, ITransformer trainedModel)
167+
MLContext mlContext, ITransformer trainedModel,
168+
ReadOnlyMemory<char>[] originalLabels)
164169
{
165170
// Create prediction function to try one prediction.
166171
var predictionEngine = mlContext.Model
167-
.CreatePredictionEngine<InMemoryImageData,
168-
ImagePrediction>(trainedModel);
172+
.CreatePredictionEngine<ImageData, ImagePrediction>(trainedModel);
169173

170174
// Load test images
171-
IEnumerable<InMemoryImageData> testImages =
172-
LoadInMemoryImagesFromDirectory(imagesForPredictions, false);
175+
IEnumerable<ImageData> testImages = LoadImagesFromDirectory(
176+
imagesForPredictions, false);
173177

174178
// Create an in-memory image object from the first image in the test data.
175-
InMemoryImageData imageToPredict = new InMemoryImageData
179+
ImageData imageToPredict = new ImageData
176180
{
177-
Image = testImages.First().Image
181+
ImagePath = testImages.First().ImagePath
178182
};
179183

180184
// Predict on the single image.
181185
var prediction = predictionEngine.Predict(imageToPredict);
186+
var index = prediction.PredictedLabel;
182187

183-
Console.WriteLine($"Scores : [{string.Join(",", prediction.Score)}], " +
184-
$"Predicted Label : {prediction.PredictedLabel}");
188+
Console.WriteLine($"ImageFile : " +
189+
$"[{Path.GetFileName(imageToPredict.ImagePath)}], " +
190+
$"Scores : [{string.Join(",", prediction.Score)}], " +
191+
$"Predicted Label : {originalLabels[index]}");
185192
}
186193

187194
// Evaluate the trained model on the passed test dataset.
@@ -243,42 +250,6 @@ public static IEnumerable<ImageData> LoadImagesFromDirectory(string folder,
243250
}
244251
}
245252

246-
// Load In memory raw images from directory.
247-
public static IEnumerable<InMemoryImageData>
248-
LoadInMemoryImagesFromDirectory(string folder,
249-
bool useFolderNameAsLabel = true)
250-
{
251-
var files = Directory.GetFiles(folder, "*",
252-
searchOption: SearchOption.AllDirectories);
253-
foreach (var file in files)
254-
{
255-
if (Path.GetExtension(file) != ".jpg")
256-
continue;
257-
258-
var label = Path.GetFileName(file);
259-
if (useFolderNameAsLabel)
260-
label = Directory.GetParent(file).Name;
261-
else
262-
{
263-
for (int index = 0; index < label.Length; index++)
264-
{
265-
if (!char.IsLetter(label[index]))
266-
{
267-
label = label.Substring(0, index);
268-
break;
269-
}
270-
}
271-
}
272-
273-
yield return new InMemoryImageData()
274-
{
275-
Image = File.ReadAllBytes(file),
276-
Label = label
277-
};
278-
279-
}
280-
}
281-
282253
// Download and unzip the image dataset.
283254
public static string DownloadImageSet(string imagesDownloadFolder)
284255
{
@@ -367,16 +338,6 @@ public static string GetAbsolutePath(string relativePath)
367338
return fullPath;
368339
}
369340

370-
// InMemoryImageData class holding the raw image byte array and label.
371-
public class InMemoryImageData
372-
{
373-
[LoadColumn(0)]
374-
public byte[] Image;
375-
376-
[LoadColumn(1)]
377-
public string Label;
378-
}
379-
380341
// ImageData class holding the imagepath and label.
381342
public class ImageData
382343
{
@@ -394,7 +355,7 @@ public class ImagePrediction
394355
public float[] Score;
395356

396357
[ColumnName("PredictedLabel")]
397-
public string PredictedLabel;
358+
public UInt32 PredictedLabel;
398359
}
399360
}
400361
}

0 commit comments

Comments
 (0)