From f96761b3ca55f1ab19584b7195bc19d800e41248 Mon Sep 17 00:00:00 2001
From: Harish Kulkarni
Date: Thu, 14 Nov 2019 03:17:12 +0000
Subject: [PATCH] Fixed model saving and loading of OneVersusAllTrainer to include SoftMax (#4472)

* Fixed model saving and loading of OneVersusAllTrainer to include SoftMax

* Modified existing test to include SoftMax option

* Modified test to verify both cases: when UseSoftmax is true and false
---
 .../OneVersusAllTrainer.cs                    | 25 +++++++++++++-----
 .../Attributes/LightGBMTheoryAttribute.cs     | 26 +++++++++++++++++++
 .../TrainerEstimators/TreeEstimators.cs       |  9 ++++---
 3 files changed, 50 insertions(+), 10 deletions(-)
 create mode 100644 test/Microsoft.ML.TestFramework/Attributes/LightGBMTheoryAttribute.cs

diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs
index 2ae05af908..7ab0482acf 100644
--- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs
+++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs
@@ -364,23 +364,29 @@ private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx)
             : base(env, RegistrationName, ctx)
         {
             // *** Binary format ***
-            // bool: useDist
+            // byte: OutputFormula as byte
             // int: predictor count
-            bool useDist = ctx.Reader.ReadBoolByte();
+            OutputFormula outputFormula = (OutputFormula)ctx.Reader.ReadByte();
             int len = ctx.Reader.ReadInt32();
             Host.CheckDecode(len > 0);
 
-            if (useDist)
+            if (outputFormula == OutputFormula.Raw)
+            {
+                var predictors = new TScalarPredictor[len];
+                LoadPredictors(Host, predictors, ctx);
+                _impl = new ImplRaw(predictors);
+            }
+            else if (outputFormula == OutputFormula.ProbabilityNormalization)
             {
                 var predictors = new IValueMapperDist[len];
                 LoadPredictors(Host, predictors, ctx);
                 _impl = new ImplDist(predictors);
             }
-            else
+            else if (outputFormula == OutputFormula.Softmax)
             {
                 var predictors = new TScalarPredictor[len];
                 LoadPredictors(Host, predictors, ctx);
-                _impl = new ImplRaw(predictors);
+                _impl = new ImplSoftmax(predictors);
             }
 
             DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length);
@@ -409,9 +415,10 @@ private protected override void SaveCore(ModelSaveContext ctx)
             var preds = _impl.Predictors;
 
             // *** Binary format ***
-            // bool: useDist
+            // byte: _impl.OutputFormula as byte
             // int: predictor count
-            ctx.Writer.WriteBoolByte(_impl is ImplDist);
+            byte[] outputFormula = { (byte)_impl.OutputFormula };
+            ctx.Writer.WriteBytesNoCount(outputFormula, 1);
             ctx.Writer.Write(preds.Length);
 
             // Save other streams.
@@ -485,6 +492,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
 
         private abstract class ImplBase : ISingleCanSavePfa
         {
+            public OutputFormula OutputFormula;
             public abstract DataViewType InputType { get; }
             public abstract IValueMapper[] Predictors { get; }
             public abstract bool CanSavePfa { get; }
@@ -536,6 +544,7 @@ internal ImplRaw(TScalarPredictor[] predictors)
                 CanSavePfa = Predictors.All(m => (m as ISingleCanSavePfa)?.CanSavePfa == true);
                 Contracts.AssertValue(inputType);
                 InputType = inputType;
+                OutputFormula = OutputFormula.Raw;
             }
 
             public override ValueMapper<VBuffer<float>, VBuffer<float>> GetMapper()
@@ -601,6 +610,7 @@ internal ImplDist(IValueMapperDist[] predictors)
                 CanSavePfa = Predictors.All(m => (m as IDistCanSavePfa)?.CanSavePfa == true);
                 Contracts.AssertValue(inputType);
                 InputType = inputType;
+                OutputFormula = OutputFormula.ProbabilityNormalization;
             }
 
             private bool IsValid(IValueMapperDist mapper, ref VectorDataViewType inputType)
@@ -712,6 +722,7 @@ internal ImplSoftmax(TScalarPredictor[] predictors)
                 CanSavePfa = false;
                 Contracts.AssertValue(inputType);
                 InputType = inputType;
+                OutputFormula = OutputFormula.Softmax;
             }
 
             public override ValueMapper<VBuffer<float>, VBuffer<float>> GetMapper()
diff --git a/test/Microsoft.ML.TestFramework/Attributes/LightGBMTheoryAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/LightGBMTheoryAttribute.cs
new file mode 100644
index 0000000000..6b40383863
--- /dev/null
+++ b/test/Microsoft.ML.TestFramework/Attributes/LightGBMTheoryAttribute.cs
@@ -0,0 +1,26 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Runtime.InteropServices;
+using Microsoft.ML.TestFrameworkCommon.Attributes;
+
+namespace Microsoft.ML.TestFramework.Attributes
+{
+    /// <summary>
+    /// A theory for tests requiring LightGBM.
+    /// </summary>
+    public sealed class LightGBMTheoryAttribute : EnvironmentSpecificTheoryAttribute
+    {
+        public LightGBMTheoryAttribute() : base("LightGBM is 64-bit only")
+        {
+        }
+
+        /// <inheritdoc />
+        protected override bool IsEnvironmentSupported()
+        {
+            return Environment.Is64BitProcess;
+        }
+    }
+}
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
index 1c719a6f85..3e6d9e2769 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
@@ -303,12 +303,15 @@ public void LightGbmMulticlassEstimator()
         /// <summary>
         /// LightGbmMulticlass TrainerEstimator test with options
         /// </summary>
-        [LightGBMFact]
-        public void LightGbmMulticlassEstimatorWithOptions()
+        [LightGBMTheory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public void LightGbmMulticlassEstimatorWithOptions(bool softMax)
         {
             var options = new LightGbmMulticlassTrainer.Options
             {
-                EvaluationMetric = LightGbmMulticlassTrainer.Options.EvaluateMetricType.Default
+                EvaluationMetric = LightGbmMulticlassTrainer.Options.EvaluateMetricType.Default,
+                UseSoftmax = softMax
             };
 
             var (pipeline, dataView) = GetMulticlassPipeline();