Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added multiple related fixes to enable automatic addition of KeyToValue #4878

Merged
merged 2 commits into from
Feb 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -199,23 +199,26 @@ private protected override void SaveAsOnnxCore(OnnxContext ctx)
for (int iinfo = 0; iinfo < Bindings.InfoCount; ++iinfo)
outColumnNames[iinfo] = Bindings.GetColumnName(Bindings.MapIinfoToCol(iinfo));

string opType = "Binarizer";
string scoreColumn = Bindings.RowMapper.OutputSchema[Bindings.ScoreColumnIndex].Name;

OnnxNode node;
var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true);
string opType = "Binarizer";
var binarizerOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "BinarizerOutput", false);
Copy link
Member

@ganik ganik Feb 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

false [](start = 108, length = 5)

What does this false mean?
#Resolved

Copy link
Contributor Author

@harishsk harishsk Feb 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

False = Do not skip adding shape and type information (or in other words, add shape and type information). False is the default value. Technically it is not necessary to specify it.


In reply to: 383088227 [](ancestors = 383088227)

node = ctx.CreateNode(opType, ctx.GetVariableName(scoreColumn), binarizerOutput, ctx.GetNodeName(opType));
node.AddAttribute("threshold", _threshold);

string scoreColumn;
if (Bindings.RowMapper.OutputSchema[Bindings.ScoreColumnIndex].Name == "Score")
scoreColumn = outColumnNames[1];
else
string comparisonOutput = binarizerOutput;
if (Bindings.PredColType is KeyDataViewType)
{
Host.Assert(Bindings.InfoCount >= 3);
scoreColumn = outColumnNames[2];
var one = ctx.AddInitializer(1.0f, "one");
Copy link
Member

@ganik ganik Feb 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx.AddInitializer(1.0f, "one"); [](start = 26, length = 32)

Is this +1 ? to be 1 based? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is to make it one based and make it consistent with ML.NET results.


In reply to: 383088629 [](ancestors = 383088629)

var addOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "Add", false);
opType = "Add";
ctx.CreateNode(opType, new[] { binarizerOutput, one }, new[] { addOutput }, ctx.GetNodeName(opType), "");
comparisonOutput = addOutput;
}
node = ctx.CreateNode(opType, ctx.GetVariableName(scoreColumn), binarizerOutput, ctx.GetNodeName(opType));
node.AddAttribute("threshold", _threshold);

opType = "Cast";
node = ctx.CreateNode(opType, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), "");
node = ctx.CreateNode(opType, comparisonOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), "");
var predictedLabelCol = OutputSchema.GetColumnOrNull(outColumnNames[0]);
Host.Assert(predictedLabelCol.HasValue);
node.AddAttribute("to", predictedLabelCol.Value.Type.RawType);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ private protected virtual void SaveAsOnnxCore(OnnxContext ctx)
{
int colIndex = Bindings.MapIinfoToCol(iinfo);
string colName = Bindings.GetColumnName(colIndex);
colName = ctx.AddIntermediateVariable(Bindings.GetColumnType(colIndex), colName, true);
colName = ctx.AddIntermediateVariable(Bindings.GetColumnType(colIndex), colName, false);
outVariableNames[iinfo] = colName;
}

Expand Down
7 changes: 3 additions & 4 deletions src/Microsoft.ML.Data/Transforms/KeyToValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string ds
// Onnx expects the input keys to be int64s. But the input data can come from an ML.NET node that
// may output a uint32. So cast it here to ensure that the data is treated correctly
opType = "Cast";
var castNodeOutput = ctx.AddIntermediateVariable(TypeOutput, "CastNodeOutput", true);
var castNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "CastNodeOutput", true);
var castNode = ctx.CreateNode(opType, srcVariableName, castNodeOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType();
castNode.AddAttribute("to", t);
Expand Down Expand Up @@ -568,12 +568,11 @@ public void SaveAsOnnx(OnnxContext ctx)

if (!ctx.ContainsColumn(inputColumnName))
continue;

string srcVariableName = ctx.GetVariableName(inputColumnName);
var dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], info.outputColumnName, true);
var dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], info.outputColumnName);
if (!_kvMaps[iinfo].SaveOnnx(ctx, srcVariableName, dstVariableName))
{
ctx.RemoveColumn(inputColumnName, true);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
long[] termIds;
string opType = "LabelEncoder";
OnnxNode castNode;
var labelEncoderOutput = ctx.AddIntermediateVariable(_types[iinfo], "LabelEncoderOutput", true);
var labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "LabelEncoderOutput", true);

if (info.TypeSrc.GetItemType().Equals(TextDataViewType.Instance))
{
Expand All @@ -804,7 +804,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Double))
{
// LabelEncoder doesn't support double tensors, so values are cast to floats
var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true);
Copy link
Member

@ganik ganik Feb 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

null [](start = 65, length = 4)

How did it work with null before? Was there an exception? #Resolved

Copy link
Contributor Author

@harishsk harishsk Feb 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the last parameter is true, it skips adding shape and type information and therefore accepts null. This still works, but I am prepping some parts of the code base for issues I have seen when run against the master branch of ORT.
There will be an exception if you specify null and set the last parameter to false.


In reply to: 383088824 [](ancestors = 383088824)

var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "castOutput", true);
castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
castNode.AddAttribute("to", t);
Expand All @@ -815,7 +815,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Int64))
{
// LabelEncoder doesn't support mapping int64 -> int64, so values are cast to strings
var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true);
var castOutput = ctx.AddIntermediateVariable(TextDataViewType.Instance, "castOutput", true);
castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType();
castNode.AddAttribute("to", t);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ internal static ModelProto ConvertTransformListToOnnxModel(OnnxContextImpl ctx,
ch.Check(variableName != null, "The targeted pipeline can not be fully converted into a well-defined ONNX model. " +
"Please check if all steps in that pipeline are convertible to ONNX " +
"and all necessary variables are not dropped (via command line arguments).");
var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName + ".output", true);
var trueVariableName = ctx.AddIntermediateVariable(outputData.Schema[i].Type, idataviewColumnName + ".output");
ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
ctx.AddOutputVariable(outputData.Schema[i].Type, trueVariableName);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -574,13 +574,14 @@ public void SaveAsOnnxPostProcess(OnnxContext ctx, string inputName, string[] ou

string opType;
opType = "ArgMax";
var argMaxOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ArgMaxOutput", true);
var argMaxOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "ArgMaxOutput");
var argMaxNode = ctx.CreateNode(opType, inputName, argMaxOutput, ctx.GetNodeName(opType), "");
argMaxNode.AddAttribute("keepdims", 0);
argMaxNode.AddAttribute("keepdims", 1);
argMaxNode.AddAttribute("axis", 1);

opType = "Add";
var one = ctx.AddInitializer(1);
var addOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "AddOutput", true);
var addOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "AddOutput");
var addNode = ctx.CreateNode(opType, new[] { argMaxOutput, one }, new[] { addOutput }, ctx.GetNodeName(opType), "");

opType = "Cast";
Expand Down Expand Up @@ -662,9 +663,10 @@ public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string fe
var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, false);

string opType = "Concat";
var concatOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ConcatOutput", true);
var type = new VectorDataViewType(NumberDataViewType.Single, probabilityOutputs.Length);
var concatOutput = ctx.AddIntermediateVariable(type, "ConcatOutputRaw");
var concatNode = ctx.CreateNode(opType, probabilityOutputs, new[] { concatOutput }, ctx.GetNodeName(opType), "");
concatNode.AddAttribute("axis", 0);
concatNode.AddAttribute("axis", 1);

base.SaveAsOnnxPostProcess(ctx, concatOutput, outputNames);

Expand Down Expand Up @@ -793,42 +795,43 @@ public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string fe
var probabilityOutputs = base.SaveAsOnnxPreProcess(ctx, featureColumn, true);

opType = "Sum";
var sumOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScores", true);
var sumOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScores");
ctx.CreateNode(opType, probabilityOutputs, new[] { sumOutput }, ctx.GetNodeName(opType), "");

opType = "Cast";
var castOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "CastOutput", true);
var castOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "CastOutput");
var castNode = ctx.CreateNode(opType, sumOutput, castOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType();
castNode.AddAttribute("to", t);

opType = "Not";
var notOutput = ctx.AddIntermediateVariable(null, "IsSumZero", true);
var notOutput = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsSumZero");
ctx.CreateNode(opType, castOutput, notOutput, ctx.GetNodeName(opType), "");

opType = "Cast";
var castIsZeroSumToFloat = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "IsSumZeroAsFloat", true);
var castIsZeroSumToFloat = ctx.AddIntermediateVariable(NumberDataViewType.Single, "IsSumZeroAsFloat");
var castIsZeroSumToFloatNode = ctx.CreateNode(opType, notOutput, castIsZeroSumToFloat, ctx.GetNodeName(opType), "");
var t1 = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
castIsZeroSumToFloatNode.AddAttribute("to", t1);

opType = "Sum";
var sumOutputNonZero = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScoresNonZero", true);
var sumOutputNonZero = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOfScoresNonZero");
ctx.CreateNode(opType, new[] { sumOutput, castIsZeroSumToFloat },
new[] { sumOutputNonZero }, ctx.GetNodeName(opType), "");

string[] divOutputs = new string[Predictors.Length];
for (int i = 0; i < Predictors.Length; i++)
{
opType = "Div";
divOutputs[i] = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"DivOutput_{i}", true);
divOutputs[i] = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"DivOutput_{i}");
ctx.CreateNode(opType, new[] { probabilityOutputs[i], sumOutputNonZero }, new[] { divOutputs[i] }, ctx.GetNodeName(opType), "");
}

opType = "Concat";
var concatOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ConcatOutput", true);
var type = new VectorDataViewType(NumberDataViewType.Single, divOutputs.Length);
var concatOutput = ctx.AddIntermediateVariable(type, "ConcatOutputDist");
var concatNode = ctx.CreateNode(opType, divOutputs, new[] { concatOutput }, ctx.GetNodeName(opType), "");
concatNode.AddAttribute("axis", 0);
concatNode.AddAttribute("axis", 1);

base.SaveAsOnnxPostProcess(ctx, concatOutput, outputNames);

Expand Down Expand Up @@ -912,21 +915,22 @@ public override bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string fe

string opType;
opType = "Concat";
var concatOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ConcatOutput", true);
var type = new VectorDataViewType(NumberDataViewType.Single, probabilityOutputs.Length);
var concatOutput = ctx.AddIntermediateVariable(type, "ConcatOutputSoftMax");
var concatNode = ctx.CreateNode(opType, probabilityOutputs, new[] { concatOutput }, ctx.GetNodeName(opType), "");
concatNode.AddAttribute("axis", 0);
concatNode.AddAttribute("axis", 1);

opType = "Exp";
var expOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "ExpOutput", true);
var expOutput = ctx.AddIntermediateVariable(type, "ExpOutput");
var expNode = ctx.CreateNode(opType, concatOutput, expOutput, ctx.GetNodeName(opType), "");

opType = "ReduceSum";
var sumOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOutput", true);
var sumOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "SumOutput");
var sumNode = ctx.CreateNode(opType, expOutput, sumOutput, ctx.GetNodeName(opType), "");
sumNode.AddAttribute("keepdims", 0);
sumNode.AddAttribute("keepdims", 1);

opType = "Div";
var divOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "DivOutput", true);
var divOutput = ctx.AddIntermediateVariable(type, "DivOutput");
var divNode = ctx.CreateNode(opType, new[] { expOutput, sumOutput }, new[] { divOutput }, ctx.GetNodeName(opType), "");

base.SaveAsOnnxPostProcess(ctx, divOutput, outputNames);
Expand Down
24 changes: 9 additions & 15 deletions src/Microsoft.ML.Transforms/OptionalColumnTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -511,21 +511,15 @@ public void SaveAsOnnx(OnnxContext ctx)
if (!ctx.ContainsColumn(inputColumnName))
continue;

if (!SaveAsOnnxCore(ctx, iinfo, ctx.GetVariableName(inputColumnName),
ctx.AddIntermediateVariable(OutputSchema[_bindings.MapIinfoToCol(iinfo)].Type, inputColumnName)))
{
if (!SaveAsOnnxCore(ctx, ctx.GetVariableName(inputColumnName), _bindings.ColumnTypes[iinfo]))
ctx.RemoveColumn(inputColumnName, true);
}
}
}

public bool CanSaveOnnx(OnnxContext ctx) => true;

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, DataViewType columnType)
{
var columnType = _bindings.ColumnTypes[iinfo];
string inputColumnName = Source.Schema[_bindings.SrcCols[iinfo]].Name;

Type type = columnType.RawType;

int size;
Expand All @@ -537,24 +531,24 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName,
if ((type == typeof(int)) ||
(type == typeof(short)) || (type == typeof(ushort)) ||
(type == typeof(sbyte)) || (type == typeof(byte)))
ctx.AddInitializer(new int[size], type, new long[] { 1, size }, inputColumnName, false);
ctx.AddInitializer(new int[size], type, new long[] { 1, size }, srcVariableName, false);
else if (type == typeof(uint) || (type == typeof(ulong)))
ctx.AddInitializer(new ulong[size], type == typeof(ulong), new long[] { 1, size }, inputColumnName, false);
ctx.AddInitializer(new ulong[size], type == typeof(ulong), new long[] { 1, size }, srcVariableName, false);
else if (type == typeof(bool))
ctx.AddInitializer(new bool[size], new long[] { 1, size }, inputColumnName, false);
ctx.AddInitializer(new bool[size], new long[] { 1, size }, srcVariableName, false);
else if (type == typeof(long))
ctx.AddInitializer(new long[size], new long[] { 1, size }, inputColumnName, false);
ctx.AddInitializer(new long[size], new long[] { 1, size }, srcVariableName, false);
else if (type == typeof(float))
ctx.AddInitializer(new float[size], new long[] { 1, size }, inputColumnName, false);
ctx.AddInitializer(new float[size], new long[] { 1, size }, srcVariableName, false);
else if (type == typeof(double))
ctx.AddInitializer(new double[size], new long[] { 1, size }, inputColumnName, false);
ctx.AddInitializer(new double[size], new long[] { 1, size }, srcVariableName, false);
else if ((type == typeof(string)) || (columnType is TextDataViewType))
{
string[] values = new string[size];
for (int i = 0; i < size; i++)
values[i] = "";

ctx.AddInitializer(values, new long[] { 1, size }, inputColumnName, false);
ctx.AddInitializer(values, new long[] { 1, size }, srcVariableName, false);
}
else
return false;
Expand Down
Loading