Skip to content

Commit

Permalink
Added onnx export support for CopyColumns (dotnet#4486)
Browse files Browse the repository at this point in the history
  • Loading branch information
harishsk authored Nov 27, 2019
1 parent 9af92a4 commit 549b389
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
private readonly DataViewSchema _schema;
private readonly (string outputColumnName, string inputColumnName)[] _columns;

// NOTE(review): the two declarations below look like the before/after lines of a diff whose
// +/- markers were lost in this capture; a real class can contain only one of them — confirm
// which against the repository.
// This form reports the mapper as exportable only when the context's ONNX version is Experimental.
public bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental;
// This form reports the mapper as exportable unconditionally.
public bool CanSaveOnnx(OnnxContext ctx) => true;

internal Mapper(ColumnCopyingTransformer parent, DataViewSchema inputSchema, (string outputColumnName, string inputColumnName)[] columns)
: base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
Expand Down Expand Up @@ -233,15 +233,16 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()

/// <summary>
/// Walks every (outputColumnName, inputColumnName) pair in <c>_columns</c> and, for each pair whose
/// input variable is known to <paramref name="ctx"/>, adds an intermediate variable for the output
/// column and creates a node that connects the input variable to it.
/// </summary>
/// <param name="ctx">Export context used to resolve variable names, register variables and create nodes.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
// NOTE(review): two `opType` declarations appear here; this looks like a flattened diff
// (old value followed by new value) — presumably only the "Identity" line is current; confirm.
var opType = "CSharp";
var opType = "Identity";

foreach (var column in _columns)
{
// Resolve the name the export context uses for the input column.
var srcVariableName = ctx.GetVariableName(column.inputColumnName);
// Pairs whose source variable the context does not contain are skipped without emitting a node.
if (!ctx.ContainsColumn(srcVariableName))
continue;
// The boolean result of TryGetColumnIndex is not checked; colIndex is used as returned.
_schema.TryGetColumnIndex(column.inputColumnName, out int colIndex);
// The output variable is registered with the input column's type taken from _schema.
var dstVariableName = ctx.AddIntermediateVariable(_schema[colIndex].Type, column.outputColumnName);
// NOTE(review): two `node` declarations follow; as with `opType`, this looks like old/new diff
// lines — the first pair (CreateNode + "type" attribute) presumably belongs to the older version.
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
node.AddAttribute("type", LoaderSignature);
// The trailing "" argument is presumably the operator domain — TODO confirm against OnnxContext.CreateNode.
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
}
}
}
Expand Down
33 changes: 33 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,39 @@ void MulticlassTrainersOnnxConversionTest()
Done();
}

/// <summary>
/// Fits a CopyColumns transform that copies "Target" into "Target1", converts the fitted model to
/// an ONNX protobuf and saves it as "copycolumns.onnx". When ONNX Runtime is supported, the saved
/// model is applied to the same data and the copied column is compared with the third ONNX output.
/// </summary>
[Fact]
void CopyColumnsOnnxTest()
{
    var context = new MLContext(seed: 1);

    // Load the generated regression training set (semicolon separated, with a header row).
    var sourcePath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
    var data = context.Data.LoadFromTextFile<AdultData>(sourcePath, separatorChar: ';', hasHeader: true);

    // Fit the copy transform and keep its ML.NET output as the reference result.
    var copyModel = context.Transforms.CopyColumns("Target1", "Target").Fit(data);
    var expected = copyModel.Transform(data);
    var exported = context.Model.ConvertToOnnxProtobuf(copyModel, data);

    var savedModelPath = GetOutputPath("copycolumns.onnx");
    SaveOnnxModel(exported, savedModelPath, null);

    if (IsOnnxRuntimeSupported())
    {
        // Score the saved ONNX model on the same data the ML.NET pipeline was fitted on.
        string[] graphInputs = exported.Graph.Input.Select(input => input.Name).ToArray();
        string[] graphOutputs = exported.Graph.Output.Select(output => output.Name).ToArray();

        var actual = context.Transforms
            .ApplyOnnxModel(graphOutputs, graphInputs, savedModelPath)
            .Fit(data)
            .Transform(data);

        CompareSelectedR4ScalarColumns(copyModel.ColumnPairs[0].outputColumnName, graphOutputs[2], expected, actual);
    }

    Done();
}

private void CreateDummyExamplesToMakeComplierHappy()
{
var dummyExample = new BreastCancerFeatureVector() { Features = null };
Expand Down

0 comments on commit 549b389

Please sign in to comment.