From 549b389766c540f21e5597e5c7b319cbd98570ee Mon Sep 17 00:00:00 2001 From: Harish Kulkarni Date: Wed, 27 Nov 2019 19:13:54 +0000 Subject: [PATCH] Added onnx export support for CopyColumns (#4486) --- .../Transforms/ColumnCopying.cs | 9 ++--- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 33 +++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs index bd08e9c211..28642c61da 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs @@ -197,7 +197,7 @@ private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx private readonly DataViewSchema _schema; private readonly (string outputColumnName, string inputColumnName)[] _columns; - public bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental; + public bool CanSaveOnnx(OnnxContext ctx) => true; internal Mapper(ColumnCopyingTransformer parent, DataViewSchema inputSchema, (string outputColumnName, string inputColumnName)[] columns) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) @@ -233,15 +233,16 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() public void SaveAsOnnx(OnnxContext ctx) { - var opType = "CSharp"; + var opType = "Identity"; foreach (var column in _columns) { var srcVariableName = ctx.GetVariableName(column.inputColumnName); + if (!ctx.ContainsColumn(srcVariableName)) + continue; _schema.TryGetColumnIndex(column.inputColumnName, out int colIndex); var dstVariableName = ctx.AddIntermediateVariable(_schema[colIndex].Type, column.outputColumnName); - var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); - node.AddAttribute("type", LoaderSignature); + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), ""); } } } diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index a3865695cf..6c794904b7 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1187,6 +1187,39 @@ void MulticlassTrainersOnnxConversionTest() Done(); } + [Fact] + void CopyColumnsOnnxTest() + { + var mlContext = new MLContext(seed: 1); + + var trainDataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); + var dataView = mlContext.Data.LoadFromTextFile(trainDataPath, + separatorChar: ';', + hasHeader: true); + + var pipeline = mlContext.Transforms.CopyColumns("Target1", "Target"); + var model = pipeline.Fit(dataView); + var transformedData = model.Transform(dataView); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + + var onnxFileName = "copycolumns.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + + SaveOnnxModel(onnxModel, onnxModelPath, null); + + if (IsOnnxRuntimeSupported()) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareSelectedR4ScalarColumns(model.ColumnPairs[0].outputColumnName, outputNames[2], transformedData, onnxResult); + } + Done(); + } + private void CreateDummyExamplesToMakeComplierHappy() { var dummyExample = new BreastCancerFeatureVector() { Features = null };