|
7 | 7 | using System.IO;
|
8 | 8 | using System.IO.Compression;
|
9 | 9 | using System.Linq;
|
10 |
| -using System.Net; |
11 | 10 | using System.Net.Http;
|
12 | 11 | using System.Security.AccessControl;
|
13 | 12 | using System.Security.Principal;
|
14 |
| -using System.Threading.Tasks; |
15 | 13 | using Microsoft.ML.Data;
|
16 | 14 | using Microsoft.ML.Runtime;
|
17 | 15 | using Tensorflow;
|
@@ -93,6 +91,7 @@ internal static Session LoadTFSession(IExceptionContext ectx, byte[] modelBytes,
|
93 | 91 | }
|
94 | 92 | return new Session(graph);
|
95 | 93 | }
|
| 94 | + |
96 | 95 | internal static void DownloadIfNeeded(Uri address, string fileName)
|
97 | 96 | {
|
98 | 97 | using HttpClient client = new HttpClient();
|
@@ -305,33 +304,30 @@ internal static DnnModel LoadDnnModel(IHostEnvironment env, string modelPath, bo
|
305 | 304 | internal static DnnModel LoadDnnModel(IHostEnvironment env, ImageClassificationEstimator.Architecture arch, bool metaGraph = false)
|
306 | 305 | {
|
307 | 306 | var modelPath = ImageClassificationEstimator.ModelLocation[arch];
|
308 |
| - if (!File.Exists(modelPath)) |
| 307 | + if (arch == ImageClassificationEstimator.Architecture.InceptionV3) |
309 | 308 | {
|
310 |
| - if (arch == ImageClassificationEstimator.Architecture.InceptionV3) |
311 |
| - { |
312 |
| - var baseGitPath = @"https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta"; |
313 |
| - DownloadIfNeeded(new Uri($"{baseGitPath}"), @"InceptionV3.meta"); |
| 309 | + var baseGitPath = @"https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta"; |
| 310 | + DownloadIfNeeded(new Uri($"{baseGitPath}"), @"InceptionV3.meta"); |
314 | 311 |
|
315 |
| - baseGitPath = @"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/tfhub_modules.zip"; |
316 |
| - DownloadIfNeeded(new Uri($"{baseGitPath}"), @"tfhub_modules.zip"); |
317 |
| - if (!Directory.Exists(@"tfhub_modules")) |
318 |
| - ZipFile.ExtractToDirectory(Path.Combine(Directory.GetCurrentDirectory(), @"tfhub_modules.zip"), @"tfhub_modules"); |
319 |
| - } |
320 |
| - else if (arch == ImageClassificationEstimator.Architecture.ResnetV2101) |
321 |
| - { |
322 |
| - var baseGitPath = @"https://aka.ms/mlnet-resources/image/ResNet101Tensorflow/resnet_v2_101_299.meta"; |
323 |
| - DownloadIfNeeded(new Uri($"{baseGitPath}"), @"resnet_v2_101_299.meta"); |
324 |
| - } |
325 |
| - else if (arch == ImageClassificationEstimator.Architecture.MobilenetV2) |
326 |
| - { |
327 |
| - var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/MobileNetV2TensorFlow/mobilenet_v2.meta"; |
328 |
| - DownloadIfNeeded(new Uri($"{baseGitPath}"), @"mobilenet_v2.meta"); |
329 |
| - } |
330 |
| - else if (arch == ImageClassificationEstimator.Architecture.ResnetV250) |
331 |
| - { |
332 |
| - var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/ResNetV250TensorFlow/resnet_v2_50_299.meta"; |
333 |
| - DownloadIfNeeded(new Uri($"{baseGitPath}"), @"resnet_v2_50_299.meta"); |
334 |
| - } |
| 312 | + baseGitPath = @"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/tfhub_modules.zip"; |
| 313 | + DownloadIfNeeded(new Uri($"{baseGitPath}"), @"tfhub_modules.zip"); |
| 314 | + if (!Directory.Exists(@"tfhub_modules")) |
| 315 | + ZipFile.ExtractToDirectory(Path.Combine(Directory.GetCurrentDirectory(), @"tfhub_modules.zip"), @"tfhub_modules"); |
| 316 | + } |
| 317 | + else if (arch == ImageClassificationEstimator.Architecture.ResnetV2101) |
| 318 | + { |
| 319 | + var baseGitPath = @"https://aka.ms/mlnet-resources/image/ResNet101Tensorflow/resnet_v2_101_299.meta"; |
| 320 | + DownloadIfNeeded(new Uri($"{baseGitPath}"), @"resnet_v2_101_299.meta"); |
| 321 | + } |
| 322 | + else if (arch == ImageClassificationEstimator.Architecture.MobilenetV2) |
| 323 | + { |
| 324 | + var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/MobileNetV2TensorFlow/mobilenet_v2.meta"; |
| 325 | + DownloadIfNeeded(new Uri($"{baseGitPath}"), @"mobilenet_v2.meta"); |
| 326 | + } |
| 327 | + else if (arch == ImageClassificationEstimator.Architecture.ResnetV250) |
| 328 | + { |
| 329 | + var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/ResNetV250TensorFlow/resnet_v2_50_299.meta"; |
| 330 | + DownloadIfNeeded(new Uri($"{baseGitPath}"), @"resnet_v2_50_299.meta"); |
335 | 331 | }
|
336 | 332 |
|
337 | 333 | var session = GetSession(env, modelPath, metaGraph);
|
|
0 commit comments