Skip to content

Added Onnx export functionality to PCATransformer #4188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Sep 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion;
model.ModelVersion = modelVersion;
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 1 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 7 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 9 });
model.Graph = new GraphProto();
var graph = model.Graph;
graph.Node.Add(nodes);
Expand Down
70 changes: 69 additions & 1 deletion src/Microsoft.ML.PCA/PcaTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Numeric;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
Expand Down Expand Up @@ -511,7 +512,7 @@ internal static void ValidatePcaInput(IExceptionContext ectx, string name, DataV
throw ectx.ExceptSchemaMismatch(nameof(inputSchema), "input", name, "known-size vector of Single of two or more items", type.ToString());
}

private sealed class Mapper : OneToOneMapperBase
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
public sealed class ColumnSchemaInfo
{
Expand Down Expand Up @@ -596,6 +597,73 @@ private static void TransformFeatures(IExceptionContext ectx, in VBuffer<float>

dst = editor.Commit();
}

public bool CanSaveOnnx(OnnxContext ctx) => true;

public void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));

for (int i = 0; i < _numColumns; i++)
{
var colPair = _parent.ColumnPairs[i];
var transformInfo = _parent._transformInfos[i];
string inputColumnName = colPair.inputColumnName;
string outputColumnName = colPair.outputColumnName;
if (!ctx.ContainsColumn(inputColumnName))
{
ctx.RemoveColumn(colPair.outputColumnName, false);
continue;
}

var dstVariableName = ctx.AddIntermediateVariable(transformInfo.OutputType, outputColumnName);
SaveAsOnnxCore(ctx, i, ctx.GetVariableName(inputColumnName), dstVariableName);
}
}

private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
Host.CheckValue(ctx, nameof(ctx));

TransformInfo transformInfo = _parent._transformInfos[iinfo];
ColumnSchemaInfo schemaInfo = _parent._schemaInfos[iinfo];

float[] principalComponents = new float[transformInfo.Rank * transformInfo.Dimension];
for (int i = 0; i < transformInfo.Rank; i++)
{
Array.Copy(transformInfo.Eigenvectors[i], 0, principalComponents, i * transformInfo.Dimension, transformInfo.Dimension);
}
long[] pcaDims = { transformInfo.Rank, transformInfo.Dimension };
var pcaMatrix = ctx.AddInitializer(principalComponents, pcaDims, "principalComponents");

float[] zeroMean = new float[transformInfo.Rank];
if (transformInfo.MeanProjected != null)
{
Array.Copy(transformInfo.MeanProjected, zeroMean, transformInfo.Rank);
}

long[] meanDims = { transformInfo.Rank };
var zeroMeanNode = ctx.AddInitializer(zeroMean, meanDims, "meanVector");

// NB: Hack
// Currently ML.NET persists ONNX graphs in proto-buf 3 format but the Onnx runtime uses the proto-buf 2 format
// There is an incompatibility between the two where proto-buf 3 does not include variables whose values are zero
// In the Gemm node below, we want the srcVariableName matrix to be sent in without a transpose, so transA has to be zero
// Due to the incompatibility, we get an exception from the Onnx runtime
// To workaround this, we transpose the input data first with the Transpose operator and then use the Gemm operator with transA=1
// This should be removed once incompatibility is fixed.
string opType;
opType = "Transpose";
var transposeOutput = ctx.AddIntermediateVariable(schemaInfo.InputType, "TransposeOutput", true);
var transposeNode = ctx.CreateNode(opType, srcVariableName, transposeOutput, ctx.GetNodeName(opType), "");

opType = "Gemm";
var gemmNode = ctx.CreateNode(opType, new[] { transposeOutput, pcaMatrix, zeroMeanNode }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
gemmNode.AddAttribute("alpha", 1.0);
gemmNode.AddAttribute("beta", -1.0);
gemmNode.AddAttribute("transA", 1);
gemmNode.AddAttribute("transB", 1);
}
}

[TlcModule.EntryPoint(Name = "Transforms.PcaCalculator",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,7 @@
"version": "1"
},
{
"version": "7"
"version": "9"
}
]
}
3 changes: 3 additions & 0 deletions test/Microsoft.ML.TestFramework/BaseTestBaseline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,9 @@ private bool MatchNumberWithTolerance(MatchCollection firstCollection, MatchColl

public bool CompareNumbersWithTolerance(double expected, double actual, int? iterationOnCollection = null, int digitsOfPrecision = DigitsOfPrecision)
{
if (double.IsNaN(expected) && double.IsNaN(actual))
return true;

// this follows the IEEE recommendations for how to compare floating point numbers
double allowedVariance = Math.Pow(10, -digitsOfPrecision);
double delta = Round(expected, digitsOfPrecision) - Round(actual, digitsOfPrecision);
Expand Down
49 changes: 48 additions & 1 deletion test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,46 @@ public void OnnxTypeConversionTest()
CompareResults(model.ColumnPairs[0].outputColumnName, outputNames[1], mlnetResult, onnxResult);
}
}
Done();
}

[Fact]
public void PcaOnnxConversionTest()
{
var dataSource = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);

var mlContext = new MLContext(seed: 1);
var dataView = mlContext.Data.LoadFromTextFile(dataSource, new[] {
new TextLoader.Column("label", DataKind.Single, 11),
new TextLoader.Column("features", DataKind.Single, 0, 10)
}, hasHeader: true, separatorChar: ';');

bool[] zeroMeans = { true, false };
foreach (var zeroMean in zeroMeans)
{
var pipeline = ML.Transforms.ProjectToPrincipalComponents("pca", "features", rank: 5, seed: 1, ensureZeroMean: zeroMean);
var model = pipeline.Fit(dataView);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

var onnxFileName = "pca.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);

SaveOnnxModel(onnxModel, onnxModelPath, null);

if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess)
{
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedR4VectorColumns(model.ColumnPairs[0].outputColumnName, outputNames[2], transformedData, onnxResult);
}
}

Done();
}

private void CreateDummyExamplesToMakeComplierHappy()
Expand Down Expand Up @@ -845,7 +885,14 @@ private void CompareSelectedR4VectorColumns(string leftColumnName, string rightC

Assert.Equal(expected.Length, actual.Length);
for (int i = 0; i < expected.Length; ++i)
Assert.Equal(expected.GetItemOrDefault(i), actual.GetItemOrDefault(i), precision);
{
// We are using float values. But the Assert.Equal function takes doubles.
// And sometimes the converted doubles are different in their precision.
// So make sure we compare floats
float exp = expected.GetItemOrDefault(i);
float act = actual.GetItemOrDefault(i);
CompareNumbersWithTolerance(exp, act, null, precision);
}
}
}
}
Expand Down