Skip to content

Commit

Permalink
Added extraction of score column before node creation
Browse files Browse the repository at this point in the history
  • Loading branch information
Lynx1820 committed Nov 11, 2019
1 parent ea71828 commit 93388b6
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema
if (!ctx.ContainsColumn(featName))
return false;
Contracts.Assert(ctx.ContainsColumn(featName));
return mapper.SaveAsOnnx(ctx, new[] { outputNames[1] }, ctx.GetVariableName(featName));
return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featName));
}

private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) =>
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3111,7 +3111,8 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
}

string opType = "TreeEnsembleRegressor";
var node = ctx.CreateNode(opType, new[] { featureColumn }, outputNames, ctx.GetNodeName(opType));
string scoreVarName = (Utils.Size(outputNames) == 2) ? outputNames[1] : outputNames[0]; // Get Score from PredictedLabel and/or Score columns
var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType));

node.AddAttribute("post_transform", PostTransform.None.GetDescription());
node.AddAttribute("n_targets", 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,10 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input)
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputs) == 1);

string opType = "LinearRegressor";
var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType));
string scoreVarName = (Utils.Size(outputs) == 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns

var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType));
// Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT}
node.AddAttribute("post_transform", "NONE");
node.AddAttribute("targets", 1);
Expand Down
1 change: 0 additions & 1 deletion test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,6 @@ public void binaryClassificationTrainersOnnxConversionTest()
CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3);
CompareSelectedScalarColumns<Boolean>(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult);
}

}
Done();
}
Expand Down

0 comments on commit 93388b6

Please sign in to comment.