Skip to content

Commit d2e1dcb

Browse files
committed
fixed unit test, minor formatting
1 parent cf670dc commit d2e1dcb

File tree

3 files changed

+37
-42
lines changed

3 files changed

+37
-42
lines changed

src/Microsoft.ML.Dnn/DnnUtils.cs

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
using System.IO;
88
using System.IO.Compression;
99
using System.Linq;
10-
using System.Net;
1110
using System.Net.Http;
1211
using System.Security.AccessControl;
1312
using System.Security.Principal;
14-
using System.Threading.Tasks;
1513
using Microsoft.ML.Data;
1614
using Microsoft.ML.Runtime;
1715
using Tensorflow;
@@ -93,6 +91,7 @@ internal static Session LoadTFSession(IExceptionContext ectx, byte[] modelBytes,
9391
}
9492
return new Session(graph);
9593
}
94+
9695
internal static void DownloadIfNeeded(Uri address, string fileName)
9796
{
9897
using HttpClient client = new HttpClient();
@@ -305,33 +304,30 @@ internal static DnnModel LoadDnnModel(IHostEnvironment env, string modelPath, bo
305304
internal static DnnModel LoadDnnModel(IHostEnvironment env, ImageClassificationEstimator.Architecture arch, bool metaGraph = false)
306305
{
307306
var modelPath = ImageClassificationEstimator.ModelLocation[arch];
308-
if (!File.Exists(modelPath))
307+
if (arch == ImageClassificationEstimator.Architecture.InceptionV3)
309308
{
310-
if (arch == ImageClassificationEstimator.Architecture.InceptionV3)
311-
{
312-
var baseGitPath = @"https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta";
313-
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"InceptionV3.meta");
309+
var baseGitPath = @"https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta";
310+
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"InceptionV3.meta");
314311

315-
baseGitPath = @"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/tfhub_modules.zip";
316-
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"tfhub_modules.zip");
317-
if (!Directory.Exists(@"tfhub_modules"))
318-
ZipFile.ExtractToDirectory(Path.Combine(Directory.GetCurrentDirectory(), @"tfhub_modules.zip"), @"tfhub_modules");
319-
}
320-
else if (arch == ImageClassificationEstimator.Architecture.ResnetV2101)
321-
{
322-
var baseGitPath = @"https://aka.ms/mlnet-resources/image/ResNet101Tensorflow/resnet_v2_101_299.meta";
323-
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"resnet_v2_101_299.meta");
324-
}
325-
else if (arch == ImageClassificationEstimator.Architecture.MobilenetV2)
326-
{
327-
var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/MobileNetV2TensorFlow/mobilenet_v2.meta";
328-
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"mobilenet_v2.meta");
329-
}
330-
else if (arch == ImageClassificationEstimator.Architecture.ResnetV250)
331-
{
332-
var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/ResNetV250TensorFlow/resnet_v2_50_299.meta";
333-
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"resnet_v2_50_299.meta");
334-
}
312+
baseGitPath = @"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/tfhub_modules.zip";
313+
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"tfhub_modules.zip");
314+
if (!Directory.Exists(@"tfhub_modules"))
315+
ZipFile.ExtractToDirectory(Path.Combine(Directory.GetCurrentDirectory(), @"tfhub_modules.zip"), @"tfhub_modules");
316+
}
317+
else if (arch == ImageClassificationEstimator.Architecture.ResnetV2101)
318+
{
319+
var baseGitPath = @"https://aka.ms/mlnet-resources/image/ResNet101Tensorflow/resnet_v2_101_299.meta";
320+
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"resnet_v2_101_299.meta");
321+
}
322+
else if (arch == ImageClassificationEstimator.Architecture.MobilenetV2)
323+
{
324+
var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/MobileNetV2TensorFlow/mobilenet_v2.meta";
325+
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"mobilenet_v2.meta");
326+
}
327+
else if (arch == ImageClassificationEstimator.Architecture.ResnetV250)
328+
{
329+
var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/ResNetV250TensorFlow/resnet_v2_50_299.meta";
330+
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"resnet_v2_50_299.meta");
335331
}
336332

337333
var session = GetSession(env, modelPath, metaGraph);

test/Microsoft.ML.TestFramework/Attributes/TensorflowTheoryAttribute.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ protected override bool IsEnvironmentSupported()
2020
return Environment.Is64BitProcess;
2121
}
2222
}
23-
}
23+
}

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1820,22 +1820,21 @@ public void TensorflowRedownloadModelFile(ImageClassificationEstimator.Architect
18201820
using (File.Create(@"InceptionV3.meta")) { }
18211821
}
18221822

1823+
var options = new ImageClassificationEstimator.Options()
1824+
{
1825+
FeaturesColumnName = "Image",
1826+
LabelColumnName = "Label",
1827+
Arch = arch,
1828+
Epoch = 1,
1829+
BatchSize = 10,
1830+
MetricsCallback = (metrics) => Console.WriteLine(metrics),
1831+
TestOnTrainSet = false,
1832+
DisableEarlyStopping = true
1833+
};
1834+
18231835
//Create pipeline and run
18241836
var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>
1825-
.Append(mlContext.Model.ImageClassification(
1826-
"Image", "Label",
1827-
// Just by changing/selecting InceptionV3 here instead of
1828-
// ResnetV2101 you can try a different architecture/pre-trained
1829-
// model.
1830-
arch: arch,
1831-
epoch: 1,
1832-
batchSize: 10,
1833-
learningRate: 0.01f,
1834-
metricsCallback: (metrics) => Console.WriteLine(metrics),
1835-
testOnTrainSet: false,
1836-
disableEarlyStopping: true,
1837-
reuseTrainSetBottleneckCachedValues: true,
1838-
reuseValidationSetBottleneckCachedValues: true)
1837+
.Append(mlContext.Model.ImageClassification(options)
18391838
.Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel", inputColumnName: "PredictedLabel")));
18401839

18411840
var trainedModel = pipeline.Fit(trainDataset);

0 commit comments

Comments
 (0)