Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added onnx export support for SelectColumns #4590

Merged
merged 5 commits into from
Jan 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ internal abstract class OnnxContext
/// <param name="removeColumn">IDataView column to stop tracking</param>
public abstract void RemoveVariable(string variableName, bool removeColumn);

/// <summary>
/// Removes a variable from the input columns list. This function is used only by the ColumnSelectingTransformer.
/// </summary>
/// <param name="variableName">Name identifying the input to remove. NOTE(review): the visible caller passes an IDataView column name, and the visible implementation resolves that name to an ONNX variable name before removing it, so this appears to be a column name rather than an ONNX variable name — confirm and consider renaming.</param>
public abstract void RemoveInputVariable(string variableName);

/// <summary>
/// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding
/// <see cref="IDataView"/>'s column names will map to a variable in the ONNX graph if the intermediate steps
Expand Down
39 changes: 33 additions & 6 deletions src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;

Expand Down Expand Up @@ -43,6 +44,7 @@ namespace Microsoft.ML.Transforms
/// | -- | -- |
/// | Does this estimator need to look at the data to train its parameters? | No |
/// | Input columns data type | Any |
/// | Exportable to ONNX | Yes |
///
/// The resulting <xref:Microsoft.ML.Transforms.ColumnSelectingTransformer>
/// operates on the schema of a given <xref:Microsoft.ML.IDataView> by dropping or keeping selected columns from the schema.
Expand Down Expand Up @@ -520,7 +522,7 @@ private sealed class Mapper
{
private readonly IHost _host;
private readonly DataViewSchema _inputSchema;
private readonly int[] _outputToInputMap;
public readonly int[] OutputToInputMap;

public DataViewSchema InputSchema => _inputSchema;

Expand All @@ -531,17 +533,17 @@ public Mapper(ColumnSelectingTransformer transform, DataViewSchema inputSchema)
_host = transform._host.Register(nameof(Mapper));
_inputSchema = inputSchema;

_outputToInputMap = BuildOutputToInputMap(transform.SelectColumns,
OutputToInputMap = BuildOutputToInputMap(transform.SelectColumns,
transform.KeepColumns,
transform.KeepHidden,
_inputSchema);
OutputSchema = GenerateOutputSchema(_outputToInputMap, _inputSchema);
OutputSchema = GenerateOutputSchema(OutputToInputMap, _inputSchema);
}

public int GetInputIndex(int outputIndex)
{
_host.Assert(0 <= outputIndex && outputIndex < _outputToInputMap.Length);
return _outputToInputMap[outputIndex];
_host.Assert(0 <= outputIndex && outputIndex < OutputToInputMap.Length);
return OutputToInputMap[outputIndex];
}

private static int[] BuildOutputToInputMap(IEnumerable<string> selectedColumns,
Expand Down Expand Up @@ -648,7 +650,7 @@ public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column colu
public override bool IsColumnActive(DataViewSchema.Column column) => true;
}

private sealed class SelectColumnsDataTransform : IDataTransform, IRowToRowMapper, ITransformTemplate
private sealed class SelectColumnsDataTransform : IDataTransform, IRowToRowMapper, ITransformTemplate, ITransformCanSaveOnnx
{
private readonly IHost _host;
private readonly ColumnSelectingTransformer _transform;
Expand Down Expand Up @@ -725,6 +727,31 @@ DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema

IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView newSource)
=> new SelectColumnsDataTransform(env, _transform, new Mapper(_transform, newSource.Schema), newSource);

/// <summary>
/// Reports whether this transform can be exported to ONNX. Always returns <c>true</c>; <paramref name="ctx"/> is not consulted.
/// </summary>
public bool CanSaveOnnx(OnnxContext ctx) => true;

public void SaveAsOnnx(OnnxContext ctx)
{
var droppedCols = new HashSet<int>(Enumerable.Range(0, InputSchema.Count));

var outputToInputMap = _mapper.OutputToInputMap;
for(int i = 0; i < outputToInputMap.Length; i++)
{
var srcCol = InputSchema[outputToInputMap[i]];
var dstCol = OutputSchema[i];
var srcVariable = ctx.GetVariableName(srcCol.Name);
var dstVariable = ctx.AddIntermediateVariable(dstCol.Type, dstCol.Name, true);
string opType = "Identity";
ctx.CreateNode(opType, srcVariable, dstVariable, ctx.GetNodeName(opType), "");

droppedCols.Remove(srcCol.Index);
}

foreach (var srcCol in droppedCols)
{
ctx.RemoveInputVariable(InputSchema[srcCol].Name);
}
Copy link
Member

@codemzs codemzs Dec 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where are you dropping remaining columns? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default, only the columns that have connections are selected. So the columns not in the OutputToInputMap are automatically dropped.

But your comment made me look deeper into the output schema. I have fixed the output schema with an update to OnnxTransformer.


In reply to: 361511119 [](ancestors = 361511119)

}
}

private sealed class Cursor : SynchronizedCursorBase
Expand Down
9 changes: 9 additions & 0 deletions src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,15 @@ public void AddInputVariable(DataViewType type, string colName)
_inputs.Add(OnnxUtils.GetModelArgs(type, colName));
}

/// <summary>
/// Resolves <paramref name="colName"/> to its ONNX variable name, stops tracking that variable
/// (including its column), and removes the matching entry from the model's input list.
/// </summary>
/// <param name="colName">Column name that is looked up to find the ONNX variable to remove.</param>
public override void RemoveInputVariable(string colName)
{
    // Look up the ONNX variable for this column and validate the lookup result before using it.
    string variableName = TryGetVariableName(colName);
    _host.CheckValue(variableName, nameof(variableName));

    // Drop the variable (and its column) from tracking first, as in the original ordering.
    RemoveVariable(variableName, true);

    // Exactly one input argument is expected to carry this name; Single throws otherwise.
    var matchingInput = _inputs.Single(modelArg => modelArg.Name == variableName);
    _inputs.Remove(matchingInput);
}

/// <summary>
/// Retrieve the shape of an ONNX variable. Returns null if no shape for the specified variable can be found.
/// </summary>
Expand Down
248 changes: 248 additions & 0 deletions test/BaselineOutput/Common/Onnx/Transforms/SelectColumns.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
{
"irVersion": "6",
"producerName": "ML.NET",
"producerVersion": "##VERSION##",
"domain": "machinelearning.dotnet",
"graph": {
"node": [
{
"input": [
"Size"
],
"output": [
"Size0"
],
"name": "Identity",
"opType": "Identity"
},
{
"input": [
"Shape"
],
"output": [
"Shape0"
],
"name": "Identity0",
"opType": "Identity"
},
{
"input": [
"Thickness"
],
"output": [
"Thickness0"
],
"name": "Identity1",
"opType": "Identity"
},
{
"input": [
"Label"
],
"output": [
"Label0"
],
"name": "Identity2",
"opType": "Identity"
},
{
"input": [
"Size0"
],
"output": [
"Size1"
],
"name": "Identity3",
"opType": "Identity"
},
{
"input": [
"Shape0"
],
"output": [
"Shape1"
],
"name": "Identity4",
"opType": "Identity"
},
{
"input": [
"Thickness0"
],
"output": [
"Thickness1"
],
"name": "Identity5",
"opType": "Identity"
},
{
"input": [
"Label0"
],
"output": [
"Label1"
],
"name": "Identity6",
"opType": "Identity"
}
],
"name": "model",
"input": [
{
"name": "Label",
"type": {
"tensorType": {
"elemType": 9,
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "1"
}
]
}
}
}
},
{
"name": "Thickness",
"type": {
"tensorType": {
"elemType": 6,
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "1"
}
]
}
}
}
},
{
"name": "Size",
"type": {
"tensorType": {
"elemType": 6,
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "1"
}
]
}
}
}
},
{
"name": "Shape",
"type": {
"tensorType": {
"elemType": 6,
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "1"
}
]
}
}
}
}
],
"output": [
{
"name": "Size1",
"type": {
"tensorType": {
"elemType": 6,
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "1"
}
]
}
}
}
},
{
"name": "Shape1",
"type": {
"tensorType": {
"elemType": 6,
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "1"
}
]
}
}
}
},
{
"name": "Thickness1",
"type": {
"tensorType": {
"elemType": 6,
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "1"
}
]
}
}
}
},
{
"name": "Label1",
"type": {
"tensorType": {
"elemType": 9,
"shape": {
"dim": [
{
"dimValue": "1"
},
{
"dimValue": "1"
}
]
}
}
}
}
]
},
"opsetImport": [
{
"domain": "ai.onnx.ml",
"version": "2"
},
{
"version": "11"
}
]
}
Loading