Skip to content

Commit

Permalink
fixed unit test, minor formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
ashbhandare committed Oct 23, 2019
1 parent cf670dc commit d2e1dcb
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 42 deletions.
50 changes: 23 additions & 27 deletions src/Microsoft.ML.Dnn/DnnUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Security.AccessControl;
using System.Security.Principal;
using System.Threading.Tasks;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Tensorflow;
Expand Down Expand Up @@ -93,6 +91,7 @@ internal static Session LoadTFSession(IExceptionContext ectx, byte[] modelBytes,
}
return new Session(graph);
}

internal static void DownloadIfNeeded(Uri address, string fileName)
{
using HttpClient client = new HttpClient();
Expand Down Expand Up @@ -305,33 +304,30 @@ internal static DnnModel LoadDnnModel(IHostEnvironment env, string modelPath, bo
internal static DnnModel LoadDnnModel(IHostEnvironment env, ImageClassificationEstimator.Architecture arch, bool metaGraph = false)
{
var modelPath = ImageClassificationEstimator.ModelLocation[arch];
if (!File.Exists(modelPath))
if (arch == ImageClassificationEstimator.Architecture.InceptionV3)
{
if (arch == ImageClassificationEstimator.Architecture.InceptionV3)
{
var baseGitPath = @"https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta";
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"InceptionV3.meta");
var baseGitPath = @"https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta";
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"InceptionV3.meta");

baseGitPath = @"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/tfhub_modules.zip";
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"tfhub_modules.zip");
if (!Directory.Exists(@"tfhub_modules"))
ZipFile.ExtractToDirectory(Path.Combine(Directory.GetCurrentDirectory(), @"tfhub_modules.zip"), @"tfhub_modules");
}
else if (arch == ImageClassificationEstimator.Architecture.ResnetV2101)
{
var baseGitPath = @"https://aka.ms/mlnet-resources/image/ResNet101Tensorflow/resnet_v2_101_299.meta";
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"resnet_v2_101_299.meta");
}
else if (arch == ImageClassificationEstimator.Architecture.MobilenetV2)
{
var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/MobileNetV2TensorFlow/mobilenet_v2.meta";
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"mobilenet_v2.meta");
}
else if (arch == ImageClassificationEstimator.Architecture.ResnetV250)
{
var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/ResNetV250TensorFlow/resnet_v2_50_299.meta";
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"resnet_v2_50_299.meta");
}
baseGitPath = @"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/tfhub_modules.zip";
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"tfhub_modules.zip");
if (!Directory.Exists(@"tfhub_modules"))
ZipFile.ExtractToDirectory(Path.Combine(Directory.GetCurrentDirectory(), @"tfhub_modules.zip"), @"tfhub_modules");
}
else if (arch == ImageClassificationEstimator.Architecture.ResnetV2101)
{
var baseGitPath = @"https://aka.ms/mlnet-resources/image/ResNet101Tensorflow/resnet_v2_101_299.meta";
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"resnet_v2_101_299.meta");
}
else if (arch == ImageClassificationEstimator.Architecture.MobilenetV2)
{
var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/MobileNetV2TensorFlow/mobilenet_v2.meta";
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"mobilenet_v2.meta");
}
else if (arch == ImageClassificationEstimator.Architecture.ResnetV250)
{
var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/ResNetV250TensorFlow/resnet_v2_50_299.meta";
DownloadIfNeeded(new Uri($"{baseGitPath}"), @"resnet_v2_50_299.meta");
}

var session = GetSession(env, modelPath, metaGraph);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ protected override bool IsEnvironmentSupported()
return Environment.Is64BitProcess;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1820,22 +1820,21 @@ public void TensorflowRedownloadModelFile(ImageClassificationEstimator.Architect
using (File.Create(@"InceptionV3.meta")) { }
}

var options = new ImageClassificationEstimator.Options()
{
FeaturesColumnName = "Image",
LabelColumnName = "Label",
Arch = arch,
Epoch = 1,
BatchSize = 10,
MetricsCallback = (metrics) => Console.WriteLine(metrics),
TestOnTrainSet = false,
DisableEarlyStopping = true
};

//Create pipeline and run
var pipeline = mlContext.Transforms.LoadImages("Image", fullImagesetFolderPath, false, "ImagePath") // false indicates we want the image as a VBuffer<byte>
.Append(mlContext.Model.ImageClassification(
"Image", "Label",
// Just by changing/selecting InceptionV3 here instead of
// ResnetV2101 you can try a different architecture/pre-trained
// model.
arch: arch,
epoch: 1,
batchSize: 10,
learningRate: 0.01f,
metricsCallback: (metrics) => Console.WriteLine(metrics),
testOnTrainSet: false,
disableEarlyStopping: true,
reuseTrainSetBottleneckCachedValues: true,
reuseValidationSetBottleneckCachedValues: true)
.Append(mlContext.Model.ImageClassification(options)
.Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel", inputColumnName: "PredictedLabel")));

var trainedModel = pipeline.Fit(trainDataset);
Expand Down

0 comments on commit d2e1dcb

Please sign in to comment.