Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ONNX export support for VectorWhitening #4577

Merged
merged 4 commits into from
Jan 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion src/Microsoft.ML.Mkl.Components/VectorWhitening.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;

Expand Down Expand Up @@ -546,7 +547,7 @@ public static extern int Svd(Layout layout, SvdJob jobu, SvdJob jobvt,
// Builds the row mapper for this transformer over the supplied input schema.
private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
{
    return new Mapper(this, schema);
}

private sealed class Mapper : OneToOneMapperBase
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly VectorWhiteningTransformer _parent;
private readonly int[] _cols;
Expand Down Expand Up @@ -607,6 +608,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
// Note that the learned matrices in _models have the same size for both PCA and ZCA,
// so FillValues truncates the matrix, keeping only PcaNum columns.
int cslotDst = (ex.Kind == WhiteningKind.PrincipalComponentAnalysis && ex.Rank > 0) ? ex.Rank : cslotSrc;

var model = _parent._models[iinfo];
ValueGetter<VBuffer<float>> del =
(ref VBuffer<float> dst) =>
Expand All @@ -618,6 +620,51 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
return del;
}

/// <summary>
/// Reports whether this mapper can be exported to ONNX; always returns <see langword="true"/>.
/// </summary>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
Lynx1820 marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Exports every column pair of the parent transformer to ONNX. Pairs for which
/// <c>ctx.ContainsColumn</c> returns false on the input column name are skipped.
/// </summary>
/// <param name="ctx">The ONNX context that receives the variables and nodes; must not be null.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    Host.CheckValue(ctx, nameof(ctx));

    var columnPairs = _parent.ColumnPairs;
    for (int iinfo = 0; iinfo < columnPairs.Length; iinfo++)
    {
        var pair = columnPairs[iinfo];

        // Only export columns whose input is already registered in the context.
        if (ctx.ContainsColumn(pair.inputColumnName))
        {
            string srcVariableName = ctx.GetVariableName(pair.inputColumnName);
            string dstVariableName = ctx.AddIntermediateVariable(_srcTypes[iinfo], pair.outputColumnName, true);
            SaveAsOnnxCore(ctx, iinfo, srcVariableName, dstVariableName);
        }
    }
}

/// <summary>
/// Writes the ONNX nodes for column <paramref name="iinfo"/>: a Gemm node fed with the model
/// initializer, <paramref name="srcVariableName"/> and a zero initializer (attribute transB = 1),
/// whose output is passed through a Transpose node into <paramref name="dstVariableName"/>.
/// </summary>
/// <param name="ctx">The ONNX context that receives initializers, intermediate variables and nodes.</param>
/// <param name="iinfo">Index into the parent's models, columns and this mapper's source types.</param>
/// <param name="srcVariableName">ONNX variable name of the input column.</param>
/// <param name="dstVariableName">ONNX variable name that the final Transpose node writes to.</param>
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
Lynx1820 marked this conversation as resolved.
Show resolved Hide resolved
{
var model = _parent._models[iinfo];
int dimension = _srcTypes[iinfo].GetValueCount();
// The model is expected to hold a full dimension x dimension matrix.
Host.Assert(model.Length == dimension * dimension);

var parameters = _parent._columns[iinfo];
Host.Assert(parameters.Kind == WhiteningKind.PrincipalComponentAnalysis || parameters.Kind == WhiteningKind.ZeroPhaseComponentAnalysis);

// For PCA with a positive Rank only `rank` rows are exported; otherwise the full dimension is used.
int rank = (parameters.Kind == WhiteningKind.PrincipalComponentAnalysis && parameters.Rank > 0) ? parameters.Rank : dimension;
Host.CheckParam(rank <= dimension, nameof(rank), "Rank must be at most the dimension of untransformed data.");

long[] modelDimension = { rank, dimension };

var opType = "Gemm";
Lynx1820 marked this conversation as resolved.
Show resolved Hide resolved
// Only the first rank * dimension model values are stored, shaped as [rank, dimension].
var modelName = ctx.AddInitializer(model.Take(rank * dimension), modelDimension, "model");
// Scalar zero supplied as Gemm's third input.
var zeroValueName = ctx.AddInitializer((float)0);

var gemmOutput = ctx.AddIntermediateVariable(null, "GemmOutput", true);
Lynx1820 marked this conversation as resolved.
Show resolved Hide resolved
var node = ctx.CreateNode(opType, new[] { modelName, srcVariableName, zeroValueName }, new[] { gemmOutput }, ctx.GetNodeName(opType), "");
// transB = 1 asks Gemm to transpose its second input (the source variable) before multiplying.
// NOTE(review): presumably the source is a [1, dimension] row, so Gemm yields [rank, 1] and the
// Transpose below restores a [1, rank] row — confirm against the exported graph.
node.AddAttribute("transB", 1);

opType = "Transpose";
ctx.CreateNode(opType, new[] { gemmOutput }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
}

private ValueGetter<T> GetSrcGetter<T>(DataViewRow input, int iinfo)
{
Host.AssertValue(input);
Expand Down
34 changes: 34 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,40 @@ public void BinaryClassificationTrainersOnnxConversionTest()
Done();
}

/// <summary>
/// Fits two vector-whitening estimators (default kind, and PCA with rank 5), converts the model to
/// ONNX and, when the ONNX runtime is supported, compares ML.NET output with ONNX runtime output.
/// </summary>
[Fact]
public void TestVectorWhiteningOnnxConversionTest()
{
    var mlContext = new MLContext(seed: 1);

    // Column 11 is the label; columns 0 through 10 form the feature vector.
    string dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
    var loaderColumns = new[]
    {
        new TextLoader.Column("label", DataKind.Single, 11),
        new TextLoader.Column("features", DataKind.Single, 0, 10)
    };
    var dataView = mlContext.Data.LoadFromTextFile(dataPath, loaderColumns, hasHeader: true, separatorChar: ';');

    var defaultWhitening = new VectorWhiteningEstimator(mlContext, "whitened1", "features");
    var pcaWhitening = new VectorWhiteningEstimator(mlContext, "whitened2", "features", kind: WhiteningKind.PrincipalComponentAnalysis, rank: 5);
    var pipeline = defaultWhitening.Append(pcaWhitening);

    var model = pipeline.Fit(dataView);
    var transformedData = model.Transform(dataView);
    var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

    // Compare model scores produced by ML.NET and ONNX's runtime.
    if (IsOnnxRuntimeSupported())
    {
        var onnxModelPath = GetOutputPath("VectorWhitening.onnx");
        SaveOnnxModel(onnxModel, onnxModelPath, null);

        // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
        string[] inputNames = onnxModel.Graph.Input.Select(input => input.Name).ToArray();
        string[] outputNames = onnxModel.Graph.Output.Select(output => output.Name).ToArray();
        var onnxResult = mlContext.Transforms
            .ApplyOnnxModel(outputNames, inputNames, onnxModelPath)
            .Fit(dataView)
            .Transform(dataView);

        // Positions 2 and 3 hold whitened1 and whitened2 respectively.
        foreach (int columnIndex in new[] { 2, 3 })
        {
            CompareSelectedR4VectorColumns(transformedData.Schema[columnIndex].Name, outputNames[columnIndex], transformedData, onnxResult);
        }
    }

    Done();
}

private class DataPoint
{
[VectorType(3)]
Expand Down