Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modified how DataViewTypes are registered #4187

Merged
merged 6 commits into from
Sep 13, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,9 @@
<PackageReference Include="Microsoft.ML.TestModels" Version="$(MicrosoftMLTestModelsPackageVersion)" />
<PackageReference Include="System.Data.SqlClient" Version="$(SystemDataSqlClientVersion)" />
</ItemGroup>
<ItemGroup>
<None Update="column_name_test\model.onnx">
bpstark marked this conversation as resolved.
Show resolved Hide resolved
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>
</Project>
71 changes: 71 additions & 0 deletions test/Microsoft.ML.Tests/OnnxSequenceTypeWithAttributesTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
using System;
bpstark marked this conversation as resolved.
Show resolved Hide resolved
using System.Collections.Generic;
using System.Drawing;
using System.Text;
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
using Microsoft.ML.Transforms.Image;
using Microsoft.ML.Transforms.Onnx;
using Xunit;
using Xunit.Abstractions;
using System.Linq;
using System.IO;

namespace Microsoft.ML.Tests
{
public class OnnxSequenceTypeWithAttributesTest : BaseTestBaseline
{
public class ImagePrediction
{
[ColumnName("classLabel")]
[VectorType]
public string[] Prediction;

[ColumnName("loss")]
[OnnxSequenceType(typeof(IDictionary<string, float>))]
public IEnumerable<IDictionary<string, float>> Loss;
}
public class ImageInput
{
[ImageType(224, 224)]
public Bitmap Image { get; set; }
}

public OnnxSequenceTypeWithAttributesTest(ITestOutputHelper output) : base(output)
{
}
public static PredictionEngine<ImageInput, ImagePrediction> LoadModel(string onnxModelFilePath)
{
var ctx = new MLContext();
var dataView = ctx.Data.LoadFromEnumerable(new List<ImageInput>());

var pipeline = ctx.Transforms.ResizeImages(
resizing: ImageResizingEstimator.ResizingKind.Fill,
outputColumnName: "data",
imageWidth: 224,
imageHeight: 224,
inputColumnName: nameof(ImageInput.Image))
.Append(ctx.Transforms.ExtractPixels(outputColumnName: "data"))
.Append(ctx.Transforms.ApplyOnnxModel(
modelFile: onnxModelFilePath,
outputColumnNames: new[] { "classLabel", "loss" }, inputColumnNames: new[] { "data" }));

var model = pipeline.Fit(dataView);
return ctx.Model.CreatePredictionEngine<ImageInput, ImagePrediction>(model);
}

[Fact]
public void OnnxSequenceTypeWithColumnNameAttributeTest()
{
var modelFile = @"column_name_test/model.onnx";
var predictor = LoadModel(modelFile);
string image_path = Path.Combine(DataDir, "images", "banana.jpg");

var output = predictor.Predict(new ImageInput { Image = (Bitmap)Image.FromFile(image_path) });
Assert.NotEmpty(output.Prediction);
var loss = output.Loss.FirstOrDefault();
Assert.NotEmpty(loss);
Assert.True(loss[output.Prediction[0]] > 0, "Invalid output");
}
}
}
Binary file not shown.