Skip to content

Commit faffd17

Browse files
wschincodemzs
authored andcommitted
New ONNX converter interface and some tests (#2013)
* Prototype of new ONNX converter and an end-to-end test * Reuse existing code to do conversion and polish a test * Test Kmeans as well * 1. Introduce ONNX conversion as an extention to MLContext 2. Address minor comments Remove two best friends * Add test comparison * Address comments * Propose another domain name * Add missing header * One more test for one-hot encoding's conversion * Add one more test * Add logistic regression test * Add LightGBM test * Test NaN replacement and fix build for core30 and x86 * Add one more test * Update Kmeans file due to change of initialization * Add a test for onnx conversion cmd * Add word embedding test * Remove old tests * increase tol * Drop version in saved baseline file * make c30 happy * Make c30 happier * c30 needs even more
1 parent 496e185 commit faffd17

File tree

16 files changed

+3076
-1198
lines changed

16 files changed

+3076
-1198
lines changed

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ namespace Microsoft.ML
1313
/// </summary>
1414
public sealed class ModelOperationsCatalog
1515
{
16+
/// <summary>
17+
/// This is a best friend because an extension method defined in another assembly needs this field.
18+
/// </summary>
19+
[BestFriend]
1620
internal IHostEnvironment Environment { get; }
1721

1822
public ExplainabilityTransforms Explainability { get; }
@@ -33,7 +37,6 @@ protected SubCatalogBase(ModelOperationsCatalog owner)
3337
{
3438
Environment = owner.Environment;
3539
}
36-
3740
}
3841

3942
/// <summary>
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using Microsoft.ML.Core.Data;
7+
using Microsoft.ML.Data;
8+
using Microsoft.ML.Model.Onnx;
9+
using Microsoft.ML.UniversalModelFormat.Onnx;
10+
11+
namespace Microsoft.ML
12+
{
13+
public static class OnnxExportExtensions
14+
{
15+
/// <summary>
16+
/// Convert the specified <see cref="ITransformer"/> to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
17+
/// </summary>
18+
/// <param name="catalog">The class that <see cref="ConvertToOnnx(ModelOperationsCatalog, ITransformer, IDataView)"/> attached to.</param>
19+
/// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
20+
/// <param name="inputData">The input of the specified transform.</param>
21+
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
22+
public static ModelProto ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData)
23+
{
24+
var env = catalog.Environment;
25+
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable);
26+
var outputData = transform.Transform(inputData);
27+
LinkedList<ITransformCanSaveOnnx> transforms = null;
28+
using (var ch = env.Start("ONNX conversion"))
29+
{
30+
SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out transforms);
31+
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null);
32+
}
33+
}
34+
}
35+
}

src/Microsoft.ML.Onnx/SaveOnnxCommand.cs

Lines changed: 52 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Microsoft.ML.EntryPoints;
1313
using Microsoft.ML.Internal.Utilities;
1414
using Microsoft.ML.Model.Onnx;
15+
using Microsoft.ML.UniversalModelFormat.Onnx;
1516
using Newtonsoft.Json;
1617

1718
[assembly: LoadableClass(SaveOnnxCommand.Summary, typeof(SaveOnnxCommand), typeof(SaveOnnxCommand.Arguments), typeof(SignatureCommand),
@@ -113,9 +114,10 @@ public override void Run()
113114
}
114115
}
115116

116-
private void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataView source, out IDataView trueEnd, out LinkedList<ITransformCanSaveOnnx> transforms)
117+
internal static void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataView source, out IDataView trueEnd, out LinkedList<ITransformCanSaveOnnx> transforms)
117118
{
118-
Host.AssertValue(end);
119+
ch.AssertValue(end);
120+
119121
source = trueEnd = (end as CompositeDataLoader)?.View ?? end;
120122
IDataTransform transform = source as IDataTransform;
121123
transforms = new LinkedList<ITransformCanSaveOnnx>();
@@ -134,7 +136,53 @@ private void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataV
134136
transform = (source = transform.Source) as IDataTransform;
135137
}
136138

137-
Host.AssertValue(source);
139+
ch.AssertValue(source);
140+
}
141+
142+
internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx, IChannel ch, IDataView inputData, IDataView outputData,
143+
LinkedList<ITransformCanSaveOnnx> transforms, HashSet<string> inputColumnNamesToDrop=null, HashSet<string> outputColumnNamesToDrop=null)
144+
{
145+
inputColumnNamesToDrop = inputColumnNamesToDrop ?? new HashSet<string>();
146+
outputColumnNamesToDrop = outputColumnNamesToDrop ?? new HashSet<string>();
147+
HashSet<string> inputColumns = new HashSet<string>();
148+
// Create graph inputs.
149+
for (int i = 0; i < inputData.Schema.Count; i++)
150+
{
151+
string colName = inputData.Schema[i].Name;
152+
if(inputColumnNamesToDrop.Contains(colName))
153+
continue;
154+
155+
ctx.AddInputVariable(inputData.Schema[i].Type, colName);
156+
inputColumns.Add(colName);
157+
}
158+
159+
// Create graph nodes, outputs and intermediate values.
160+
foreach (var trans in transforms)
161+
{
162+
ch.Assert(trans.CanSaveOnnx(ctx));
163+
trans.SaveAsOnnx(ctx);
164+
}
165+
166+
// Add graph outputs.
167+
for (int i = 0; i < outputData.Schema.Count; ++i)
168+
{
169+
if (outputData.Schema[i].IsHidden)
170+
continue;
171+
172+
var idataviewColumnName = outputData.Schema[i].Name;
173+
174+
// Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
175+
// _inputToDrop should be removed too.
176+
if (inputColumnNamesToDrop.Contains(idataviewColumnName) || outputColumnNamesToDrop.Contains(idataviewColumnName))
177+
continue;
178+
179+
var variableName = ctx.TryGetVariableName(idataviewColumnName);
180+
var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true);
181+
ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
182+
ctx.AddOutputVariable(outputData.Schema[i].Type, trueVariableName);
183+
}
184+
185+
return ctx.MakeModel();
138186
}
139187

140188
private void Run(IChannel ch)
@@ -210,45 +258,8 @@ private void Run(IChannel ch)
210258
nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
211259
}
212260

213-
HashSet<string> inputColumns = new HashSet<string>();
214-
//Create graph inputs.
215-
for (int i = 0; i < source.Schema.Count; i++)
216-
{
217-
string colName = source.Schema[i].Name;
218-
if(_inputsToDrop.Contains(colName))
219-
continue;
220-
221-
ctx.AddInputVariable(source.Schema[i].Type, colName);
222-
inputColumns.Add(colName);
223-
}
224-
225-
//Create graph nodes, outputs and intermediate values.
226-
foreach (var trans in transforms)
227-
{
228-
Host.Assert(trans.CanSaveOnnx(ctx));
229-
trans.SaveAsOnnx(ctx);
230-
}
231-
232-
//Add graph outputs.
233-
for (int i = 0; i < end.Schema.Count; ++i)
234-
{
235-
if (end.Schema[i].IsHidden)
236-
continue;
237-
238-
var idataviewColumnName = end.Schema[i].Name;
239-
240-
// Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
241-
// _inputToDrop should be removed too.
242-
if (_inputsToDrop.Contains(idataviewColumnName) || _outputsToDrop.Contains(idataviewColumnName))
243-
continue;
244-
245-
var variableName = ctx.TryGetVariableName(idataviewColumnName);
246-
var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true);
247-
ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
248-
ctx.AddOutputVariable(end.Schema[i].Type, trueVariableName);
249-
}
261+
var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop);
250262

251-
var model = ctx.MakeModel();
252263
using (var file = Host.CreateOutputFile(_outputModelPath))
253264
using (var stream = file.CreateWriteStream())
254265
model.WriteTo(stream);

0 commit comments

Comments
 (0)