Skip to content

Commit

Permalink
draft regression test
Browse files Browse the repository at this point in the history
  • Loading branch information
Lynx1820 committed Nov 8, 2019
1 parent b26092e commit 81381e2
Showing 1 changed file with 241 additions and 14 deletions.
255 changes: 241 additions & 14 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,52 @@
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using Microsoft.ML.Transforms.Onnx;
using Microsoft.ML.Transforms.Text;
using Newtonsoft.Json;
using Xunit;
using Xunit.Abstractions;
using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper;

namespace Microsoft.ML.Tests
{
public class OnnxConversionTest : BaseTestBaseline

public class OnnxConversionTest : BaseTestBaseline
{

private static IEnumerable<DataPoint2> GenerateRandomDataPoints(int count,
int seed = 0)
{
var random = new Random(seed);
for (int i = 0; i < count; i++)
{
float label = (float)random.NextDouble();
yield return new DataPoint2
{
Label = label,
// Create random features that are correlated with the label.
Features = Enumerable.Repeat(label, 50).Select(
x => x + (float)random.NextDouble()).ToArray()
};
}
}

// Example with label and 50 feature values. A data set is a collection of
// such examples.
private class DataPoint2
{
public float Label { get; set; }
[VectorType(50)]
public float[] Features { get; set; }
}

// Class used to capture predictions.
private class Prediction
{
// Original label.
public float Label { get; set; }
// Predicted score from the trainer.
public float Score { get; set; }
}
private class AdultData
{
[LoadColumn(0, 10), ColumnName("FeatureVector")]
Expand Down Expand Up @@ -108,8 +145,7 @@ public void SimpleEndToEndOnnxConversionTest()
private class BreastCancerFeatureVector
{
[LoadColumn(1, 9), VectorType(9)]
public float[] Features;
}
public float[] Features; }

private class BreastCancerCatFeatureExample
{
Expand Down Expand Up @@ -187,7 +223,160 @@ public void KmeansOnnxConversionTest()
Done();
}

private class DataPoint
[Fact]
public void WordEmbeddingEstimatorOnnxConversionTest() //can't find the class - maybe
{
// Step 1: Create and train a ML.NET pipeline.
var mlContext = new MLContext(seed: 1);
string dataPath = GetDataPath(TestDatasets.Sentiment.trainFilename);
var data = new TextLoader(ML,
new TextLoader.Options()
{
Separator = "\t",
HasHeader = true,
Columns = new[]
{
new TextLoader.Column("Label", DataKind.Boolean, 0),
new TextLoader.Column("SentimentText", DataKind.String, 1)
}
}).Load(GetDataPath(dataPath));

IEstimator<ITransformer>[] estimators = { };
var textPipeline = mlContext.Transforms.Text.NormalizeText("SentimentText")
.Append(mlContext.Transforms.Text.TokenizeIntoWords("Tokens",
"SentimentText"))
.Append(mlContext.Transforms.Text.ApplyWordEmbedding("Features",
"Tokens", WordEmbeddingEstimator.PretrainedModelKind
.SentimentSpecificWordEmbedding));
var model = textPipeline.Fit(data);
var transformedData = model.Transform(data);

var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data);
// Compare results produced by ML.NET and ONNX's runtime.
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess)
{
var onnxFileName = "WordEmbeddingEstimator.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);

// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(data);
var onnxResult = onnxTransformer.Transform(data);
CompareSelectedR4VectorColumns("Score", "Score0", transformedData, onnxResult, 3);
}
Done();
}

[Fact]
// Conversion tests for regression
public void regressionOnnxConversionTest()
{
/*
var mlContext = new MLContext(seed: 1);
string dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
var dataView = mlContext.Data.LoadFromTextFile<AdultData>(dataPath,
separatorChar: ';',
hasHeader: true);
IEstimator<ITransformer>[] estimators = {
//mlContext.Regression.Trainers.Ols(new OlsTrainer.Options() {
// LabelColumnName = "Target",
// FeatureColumnName = "FeatureVector",
//}),
//mlContext.Regression.Trainers.OnlineGradientDescent(new OnlineGradientDescentTrainer.Options(){
// LabelColumnName = "Target",
// FeatureColumnName = "FeatureVector",
//}),
//mlContext.Transforms.DetectAnomalyBySrCnn("Target","FeatureVector"), // needs separate data
mlContext.Regression.Trainers.FastForest("Target", "FeatureVector"),
//mlContext.Regression.Trainers.FastTree("Target", "FeatureVector"),
//mlContext.Regression.Trainers.FastTreeTweedie("Target", "FeatureVector"),
//mlContext.Regression.Trainers.LightGbm("Target","FeatureVector"),
//mlContext.Regression.Trainers.LbfgsPoissonRegression("Target", "FeatureVector"),
};
*/
// Create a new context for ML.NET operations. It can be used for
// exception tracking and logging, as a catalog of available operations
// and as the source of randomness. Setting the seed to a fixed number
// in this example to make outputs deterministic.
var mlContext = new MLContext(seed: 0);

// Create a list of training data points.
var dataPoints = GenerateRandomDataPoints(1000);

// Convert the list of data points to an IDataView object, which is
// consumable by ML.NET API.
var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

// Define the trainer.
var pipeline = mlContext.Regression.Trainers.FastTreeTweedie(
labelColumnName: nameof(DataPoint2.Label),
featureColumnName: nameof(DataPoint2.Features));

// Train the model.
var model = pipeline.Fit(trainingData);

// Create testing data. Use different random seed to make it different
// from training data.
var data = mlContext.Data.LoadFromEnumerable(
GenerateRandomDataPoints(5, seed: 123));

// Run the model on test data set.
var transformedTestData = model.Transform(data);
// Convert IDataView object to a list.
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data);
// Convert IDataView object to a list.
var predictions = mlContext.Data.CreateEnumerable<Prediction>(
transformedTestData, reuseRowObject: false).ToList();
foreach (var p in predictions)
System.Diagnostics.Debug.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}");
// Compare results produced by ML.NET and ONNX's runtime.
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess)
{
var onnxFileName = "test.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);

// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(data);
var onnxResult = onnxTransformer.Transform(data);
CompareSelectedR4ScalarColumns("Label", "Score0", data, onnxResult, 3);
}
Done();
/*var initialPipeline = mlContext.Transforms.NormalizeMinMax("FeatureVector");
foreach (var estimator in estimators)
{
//var pipeline = initialPipeline.Append(estimator);
var pipeline = estimator;
var model = pipeline.Fit(dataView);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
var onnxFileName = $"{estimator.ToString()}.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);
// Compare model scores produced by ML.NET and ONNX's runtime.
if (IsOnnxRuntimeSupported())
{
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView); //switched to 2 vause
CompareSelectedR4ScalarColumns(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult, 0); // compare score results
}
} */
//Done();
}
private class DataPoint
{
[VectorType(3)]
public float[] Features { get; set; }
Expand Down Expand Up @@ -380,8 +569,7 @@ public void LogisticRegressionOnnxConversionTest()
var trainDataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
var mlContext = new MLContext(seed: 1);
var data = mlContext.Data.LoadFromTextFile<AdultData>(trainDataPath,
separatorChar: ';'
,
separatorChar: ';',
hasHeader: true);
var cachedTrainData = mlContext.Data.Cache(data);
var dynamicPipeline =
Expand Down Expand Up @@ -658,15 +846,21 @@ public void WordEmbeddingsTest()
var model = pipeline.Fit(data);
var transformedData = model.Transform(data);

var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Transforms", "Sentiment");
var onnxTextName = "SmallWordEmbed.txt";
var onnxFileName = "SmallWordEmbed.onnx";
var onnxTextPath = GetOutputPath(subDir, onnxTextName);
var onnxFilePath = GetOutputPath(subDir, onnxFileName);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data);
SaveOnnxModel(onnxModel, onnxFilePath, onnxTextPath);
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess)
{
var onnxFileName = "WordEmbeddingEstimator.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);

CheckEquality(subDir, onnxTextName, parseOption: NumberParseOption.UseSingle);
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(data);
var onnxResult = onnxTransformer.Transform(data);
CompareSelectedR4VectorColumns("Embed", "Embed0", transformedData, onnxResult);
}
Done();
}

Expand Down Expand Up @@ -984,11 +1178,44 @@ private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightC

// Scalar such as R4 (float) is converted to [1, 1]-tensor in ONNX format for consitency of making batch prediction.
Assert.Equal(1, actual.Length);
Assert.Equal(expected, actual.GetItemOrDefault(0), precision);
//Assert.Equal(expected, actual.GetItemOrDefault(0), precision);
//Output.WriteLine(actual.GetItemOrDefault(0));
System.Diagnostics.Debug.WriteLine("Actual: " + actual.GetItemOrDefault(0));
System.Diagnostics.Debug.WriteLine("Expected: " + expected);
}
}
}

private void CompareSelectedScalarColumns<T>(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
{
var leftColumn = left.Schema[leftColumnName];
var rightColumn = right.Schema[rightColumnName];

using (var expectedCursor = left.GetRowCursor(leftColumn))
using (var actualCursor = right.GetRowCursor(rightColumn))
{
T expected = default;
VBuffer<T> actual = default;
var expectedGetter = expectedCursor.GetGetter<T>(leftColumn);
var actualGetter = actualCursor.GetGetter<VBuffer<T>>(rightColumn);
while (expectedCursor.MoveNext() && actualCursor.MoveNext())
{
expectedGetter(ref expected);
actualGetter(ref actual);
var actualVal = actual.GetItemOrDefault(0);

Assert.Equal(1, actual.Length);

if (typeof(T) == typeof(ReadOnlyMemory<Char>))
Assert.Equal(expected.ToString(), actualVal.ToString());
else
Assert.Equal(expected, actualVal);
}
}
}



private void SaveOnnxModel(ModelProto model, string binaryFormatPath, string textFormatPath)
{
DeleteOutputPath(binaryFormatPath); // Clean if such a file exists.
Expand Down

0 comments on commit 81381e2

Please sign in to comment.