Skip to content

fix test errors in unix machines #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: torch
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 41 additions & 32 deletions test/Microsoft.ML.Tests/Torch/TorchTests.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFramework.Attributes;
Expand All @@ -24,53 +26,60 @@ private class TestReLUModelData
[TorchFact]
public void TorchScoringReLUTest()
{
var mlContext = new MLContext();
var tensor = new float[] { -1, -1, 0, 1, 1 }.ToTorchTensor(dimensions: new long[] { 5 });
var data = new TestReLUModelData
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
Features = tensor.Data<float>().ToArray()
};
var dataPoint = new List<TestReLUModelData>() { data };
var mlContext = new MLContext();
var tensor = new float[] { -1, -1, 0, 1, 1 }.ToTorchTensor(dimensions: new long[] { 5 });
var data = new TestReLUModelData
{
Features = tensor.Data<float>().ToArray()
};
var dataPoint = new List<TestReLUModelData>() { data };

var dataView = mlContext.Data.LoadFromEnumerable(dataPoint);
var dataView = mlContext.Data.LoadFromEnumerable(dataPoint);

var output = mlContext.Model
.LoadTorchModel(GetDataPath("Torch/relu.pt"))
.ScoreTorchModel("Features", new long[] { 5 })
.Fit(dataView)
.Transform(dataView);
var output = mlContext.Model
.LoadTorchModel(GetDataPath("Torch/relu.pt"))
.ScoreTorchModel("Features", new long[] { 5 })
.Fit(dataView)
.Transform(dataView);

var transformedData = mlContext.Data.CreateEnumerable<TestReLUModelData>(output, false).ToArray()[0].Features;
Assert.True(transformedData.Length == 5);
Assert.Equal(transformedData, new float[] { 0, 0, 0, 1, 1 });
var transformedData = mlContext.Data.CreateEnumerable<TestReLUModelData>(output, false).ToArray()[0].Features;

Assert.True(transformedData.Length == 5);
Assert.Equal(transformedData, new float[] { 0, 0, 0, 1, 1 });
}
}

[TorchFact]
public void TorchTransformerWorkoutTest()
{
var mlContext = new MLContext();
var tensorData = FloatTensor.Random(new long[] { 5 });
var datapoint = new TestReLUModelData
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
Features = tensorData.Data<float>().ToArray()
};
var data = new List<TestReLUModelData>() { datapoint, datapoint, datapoint, datapoint, datapoint };
var mlContext = new MLContext();
var tensorData = FloatTensor.Random(new long[] { 5 });
var datapoint = new TestReLUModelData
{
Features = tensorData.Data<float>().ToArray()
};
var data = new List<TestReLUModelData>() { datapoint, datapoint, datapoint, datapoint, datapoint };

var dataView = mlContext.Data.LoadFromEnumerable(data);
var dataView = mlContext.Data.LoadFromEnumerable(data);

var estimator = mlContext.Model.LoadTorchModel(GetDataPath("Torch/relu.pt"))
.ScoreTorchModel("TorchOutput", new long[] { 5 }, "Features");
var estimator = mlContext.Model.LoadTorchModel(GetDataPath("Torch/relu.pt"))
.ScoreTorchModel("TorchOutput", new long[] { 5 }, "Features");

TestEstimatorCore(estimator, dataView);
TestEstimatorCore(estimator, dataView);

var output = estimator.Fit(dataView)
.Transform(dataView);
var output = estimator.Fit(dataView)
.Transform(dataView);

var transformedData = mlContext.Data.CreateEnumerable<TestReLUModelData>(output, false).ToArray()[0].Features;
Assert.True(transformedData.Length == 5);
foreach (var elt in transformedData)
Assert.True(elt >= 0);
var transformedData = mlContext.Data.CreateEnumerable<TestReLUModelData>(output, false).ToArray()[0].Features;

Assert.True(transformedData.Length == 5);
foreach (var elt in transformedData)
Assert.True(elt >= 0);
}
}
}
}