Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/Microsoft.ML.ImageAnalytics/ImageLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,17 @@ private Delegate MakeGetterImageDataViewType(DataViewRow input, int iinfo, Func<
{
Contracts.AssertValue(input);
Contracts.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
disposer = null;
var lastImage = default(Bitmap);

disposer = () =>
{
if (lastImage != null)
{
lastImage.Dispose();
lastImage = null;
}
};

var getSrc = input.GetGetter<ReadOnlyMemory<char>>(input.Schema[ColMapNewToOld[iinfo]]);
ReadOnlyMemory<char> src = default;
ValueGetter<Bitmap> del =
Expand Down Expand Up @@ -247,6 +257,8 @@ private Delegate MakeGetterImageDataViewType(DataViewRow input, int iinfo, Func<
if (dst.PixelFormat == System.Drawing.Imaging.PixelFormat.DontCare)
throw Host.Except($"Failed to load image {src.ToString()}.");
}

lastImage = dst;
};

return del;
Expand Down
1 change: 0 additions & 1 deletion src/Microsoft.ML.ImageAnalytics/ImageResizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
{
if (src != null)
{
src.Dispose();
src = null;
}
};
Expand Down
98 changes: 98 additions & 0 deletions test/Microsoft.ML.Tests/ImagesTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -908,5 +908,103 @@ private class DataPoint
[VectorType(InputSize)]
public double[] Features { get; set; }
}

/// <summary>
/// Test row type whose image is already held in memory as a <see cref="Bitmap"/>,
/// paired with its label.
/// </summary>
public class InMemoryImage
{
    [ImageType(229, 299)]
    public Bitmap LoadedImage;
    public string Label;

    /// <summary>
    /// Reads a text file whose column 0 is an image file name and column 1 is a label,
    /// and loads every referenced image from <paramref name="imageFolder"/>.
    /// </summary>
    /// <param name="mlContext">Context used to create the text loader.</param>
    /// <param name="tsvPath">Path of the text file listing image names and labels.</param>
    /// <param name="imageFolder">Folder that the image file names are combined with.</param>
    /// <returns>One <see cref="InMemoryImage"/> per row, in row order.</returns>
    public static List<InMemoryImage> LoadFromTsv(MLContext mlContext, string tsvPath, string imageFolder)
    {
        var columns = new[]
        {
            new TextLoader.Column("ImagePath", DataKind.String, 0),
            new TextLoader.Column("Label", DataKind.String, 1),
        };
        var rows = mlContext.Data.LoadFromTextFile(tsvPath, columns: columns);
        var loaded = new List<InMemoryImage>();

        using (var rowCursor = rows.GetRowCursorForAllColumns())
        {
            var readPath = rowCursor.GetGetter<ReadOnlyMemory<char>>(rows.Schema["ImagePath"]);
            var readLabel = rowCursor.GetGetter<ReadOnlyMemory<char>>(rows.Schema["Label"]);
            ReadOnlyMemory<char> rawPath = default;
            ReadOnlyMemory<char> rawLabel = default;

            while (rowCursor.MoveNext())
            {
                // The getters fill the reusable buffers with the current row's values.
                readPath(ref rawPath);
                readLabel(ref rawLabel);

                var item = new InMemoryImage();
                item.Label = rawLabel.ToString();
                item.LoadedImage = (Bitmap)Image.FromFile(Path.Combine(imageFolder, rawPath.ToString()));
                loaded.Add(item);
            }
        }

        return loaded;
    }
}

/// <summary>
/// Output row type: an <see cref="InMemoryImage"/> extended with a second bitmap field,
/// <see cref="ResizedImage"/>, annotated as a 100 x 100 image type.
/// </summary>
public class InMemoryImageOutput : InMemoryImage
{
// Field name matches the "ResizedImage" output column produced by the resize pipeline in ResizeInMemoryImages.
[ImageType(100, 100)]
public Bitmap ResizedImage;
}

/// <summary>
/// Verifies that resizing in-memory images yields a 100-pixel-high output column while
/// leaving the caller's original <see cref="Bitmap"/> objects unresized, still referenced
/// by the prediction output, and not disposed.
/// </summary>
[Fact]
public void ResizeInMemoryImages()
{
    var mlContext = new MLContext(seed: 1);
    var dataFile = GetDataPath("images/images.tsv");
    var imageFolder = Path.GetDirectoryName(dataFile);
    var dataObjects = InMemoryImage.LoadFromTsv(mlContext, dataFile, imageFolder);

    var dataView = mlContext.Data.LoadFromEnumerable<InMemoryImage>(dataObjects);
    var pipeline = mlContext.Transforms.ResizeImages("ResizedImage", 100, 100, nameof(InMemoryImage.LoadedImage));

    // Check that the output is resized, and that it didn't resize the original image object.
    // The resized image is read from the last column of the first previewed row.
    var model = pipeline.Fit(dataView);
    var resizedDV = model.Transform(dataView);
    var rowView = resizedDV.Preview().RowView;
    var resizedImage = (Bitmap)rowView.First().Values.Last().Value;
    Assert.Equal(100, resizedImage.Height);
    Assert.NotEqual(100, dataObjects[0].LoadedImage.Height);

    // Also check usage of the prediction engine,
    // and that the references to the original image objects aren't lost.
    var predEngine = mlContext.Model.CreatePredictionEngine<InMemoryImage, InMemoryImageOutput>(model);
    for (int i = 0; i < dataObjects.Count; i++)
    {
        var prediction = predEngine.Predict(dataObjects[i]);
        Assert.Equal(100, prediction.ResizedImage.Height);
        Assert.NotEqual(100, prediction.LoadedImage.Height);

        // Reference identity: the input bitmap is passed through untouched,
        // and the resized bitmap is a distinct object.
        Assert.Same(dataObjects[i].LoadedImage, prediction.LoadedImage);
        Assert.NotSame(dataObjects[i].LoadedImage, prediction.ResizedImage);
    }

    // Check that the last in-memory image hasn't been disposed
    // by running ResizeImageTransformer (see https://github.com/dotnet/machinelearning/issues/4126).
    // Reading a property of the bitmap is used as the probe: any exception is treated as "disposed".
    bool disposed = false;
    try
    {
        _ = dataObjects.Last().LoadedImage.Height;
    }
    catch
    {
        disposed = true;
    }

    Assert.False(disposed, "The last in memory image had been disposed by running ResizeImageTransformer");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.Drawing;
using System.IO;
using System.IO.Compression;
using System.Linq;
Expand All @@ -18,6 +19,7 @@
using Microsoft.ML.Transforms;
using Microsoft.ML.Transforms.Image;
using Microsoft.ML.TensorFlow;
using InMemoryImage = Microsoft.ML.Tests.ImageTests.InMemoryImage;
using Xunit;
using Xunit.Abstractions;
using static Microsoft.ML.DataOperationsCatalog;
Expand Down Expand Up @@ -1126,6 +1128,35 @@ public void TensorFlowTransformCifarSavedModel()
}
}

// Smoke test: cross-validation must be able to run over in-memory images.
// Only the number of returned folds is asserted; the result values are not inspected.
// See issue https://github.com/dotnet/machinelearning/issues/4126
[TensorFlowFact]
public void TensorFlowTransformCifarCrossValidationWithInMemoryImages()
{
    var mlContext = new MLContext(seed: 1);
    var modelLocation = "cifar_saved_model";
    using var tensorFlowModel = mlContext.Model.LoadTensorFlowModel(modelLocation);

    // Take the resize target from the first two dimensions of the model's "Input" column.
    var inputSchema = tensorFlowModel.GetInputSchema();
    Assert.True(inputSchema.TryGetColumnIndex("Input", out int inputIndex));
    var inputType = (VectorDataViewType)inputSchema[inputIndex].Type;
    var imageHeight = inputType.Dimensions[0];
    var imageWidth = inputType.Dimensions[1];

    // Load the images into memory up front and expose them as a data view.
    var tsvPath = GetDataPath("images/images.tsv");
    var images = InMemoryImage.LoadFromTsv(mlContext, tsvPath, Path.GetDirectoryName(tsvPath));
    var dataView = mlContext.Data.LoadFromEnumerable<InMemoryImage>(images);

    // Build each stage separately, then chain them in order.
    var resize = mlContext.Transforms.ResizeImages("ResizedImage", imageWidth, imageHeight, nameof(InMemoryImage.LoadedImage));
    var extractPixels = mlContext.Transforms.ExtractPixels("Input", "ResizedImage", interleavePixelColors: true);
    var score = tensorFlowModel.ScoreTensorFlowModel("Output", "Input");
    var labelToKey = mlContext.Transforms.Conversion.MapValueToKey("Label");
    var trainer = mlContext.MulticlassClassification.Trainers.NaiveBayes("Label", "Output");
    var pipeline = resize
        .Append(extractPixels)
        .Append(score)
        .Append(labelToKey)
        .Append(trainer);

    var foldResults = mlContext.MulticlassClassification.CrossValidate(dataView, pipeline, 2);
    Assert.Equal(2, foldResults.Count());
}

// This test has been created as result of https://github.com/dotnet/machinelearning/issues/2156.
[TensorFlowFact]
public void TensorFlowGettingSchemaMultipleTimes()
Expand Down