From 93388b60b5286b67ad66e6006c3bc9631733a5c0 Mon Sep 17 00:00:00 2001 From: Keren Fuentes Date: Mon, 11 Nov 2019 11:19:59 -0800 Subject: [PATCH] Added extraction of score column before node creation --- .../Scorers/SchemaBindablePredictorWrapper.cs | 2 +- src/Microsoft.ML.FastTree/FastTree.cs | 3 ++- .../Standard/LinearModelParameters.cs | 6 +++--- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 1 - 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index 1367249c480..b7f6c4da871 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -320,7 +320,7 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema if (!ctx.ContainsColumn(featName)) return false; Contracts.Assert(ctx.ContainsColumn(featName)); - return mapper.SaveAsOnnx(ctx, new[] { outputNames[1] }, ctx.GetVariableName(featName)); + return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featName)); } private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) => diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 82772ff151e..c0bd8e52d2d 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -3111,7 +3111,8 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string } string opType = "TreeEnsembleRegressor"; - var node = ctx.CreateNode(opType, new[] { featureColumn }, outputNames, ctx.GetNodeName(opType)); + string scoreVarName = (Utils.Size(outputNames) == 2) ? outputNames[1] : outputNames[0]; // Get Score from PredictedLabel and/or Score columns + var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType)); node.AddAttribute("post_transform", PostTransform.None.GetDescription()); node.AddAttribute("n_targets", 1); diff --git a/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs b/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs index 27d9ad6f250..9f6c45ae4d6 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs @@ -240,10 +240,10 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn) { Host.CheckValue(ctx, nameof(ctx)); - Host.Check(Utils.Size(outputs) == 1); - string opType = "LinearRegressor"; - var node = ctx.CreateNode(opType, new[] { featureColumn }, outputs, ctx.GetNodeName(opType)); + string scoreVarName = (Utils.Size(outputs) == 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns + + var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType)); // Selection of logit or probit output transform. enum {'NONE', 'LOGIT', 'PROBIT} node.AddAttribute("post_transform", "NONE"); node.AddAttribute("targets", 1); diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 3cd22a8e025..d30ad47148c 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -240,7 +240,6 @@ public void binaryClassificationTrainersOnnxConversionTest() CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); CompareSelectedScalarColumns(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); } - } Done(); }