Skip to content

Commit

Permalink
Fixed model saving and loading of OneVersusAllTrainer to include Soft…
Browse files Browse the repository at this point in the history
…Max (#4472)

* Fixed model saving and loading of OneVersusAllTrainer to include SoftMax

* Modified existing test to include SoftMax option

* Modified test to verify both cases: when UseSoftmax is true and false
  • Loading branch information
harishsk committed Nov 14, 2019
1 parent d45cc8a commit f96761b
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -364,23 +364,29 @@ private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx)
: base(env, RegistrationName, ctx)
{
// *** Binary format ***
// bool: useDist
// byte: OutputFormula as byte
// int: predictor count
bool useDist = ctx.Reader.ReadBoolByte();
OutputFormula outputFormula = (OutputFormula)ctx.Reader.ReadByte();
int len = ctx.Reader.ReadInt32();
Host.CheckDecode(len > 0);

if (useDist)
if (outputFormula == OutputFormula.Raw)
{
var predictors = new TScalarPredictor[len];
LoadPredictors(Host, predictors, ctx);
_impl = new ImplRaw(predictors);
}
else if (outputFormula == OutputFormula.ProbabilityNormalization)
{
var predictors = new IValueMapperDist[len];
LoadPredictors(Host, predictors, ctx);
_impl = new ImplDist(predictors);
}
else
else if (outputFormula == OutputFormula.Softmax)
{
var predictors = new TScalarPredictor[len];
LoadPredictors(Host, predictors, ctx);
_impl = new ImplRaw(predictors);
_impl = new ImplSoftmax(predictors);
}

DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length);
Expand Down Expand Up @@ -409,9 +415,10 @@ private protected override void SaveCore(ModelSaveContext ctx)
var preds = _impl.Predictors;

// *** Binary format ***
// bool: useDist
// byte: _impl.OutputFormula as byte
// int: predictor count
ctx.Writer.WriteBoolByte(_impl is ImplDist);
byte[] outputFormula = { (byte)_impl.OutputFormula };
ctx.Writer.WriteBytesNoCount(outputFormula, 1);
ctx.Writer.Write(preds.Length);

// Save other streams.
Expand Down Expand Up @@ -485,6 +492,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)

private abstract class ImplBase : ISingleCanSavePfa
{
public OutputFormula OutputFormula;
public abstract DataViewType InputType { get; }
public abstract IValueMapper[] Predictors { get; }
public abstract bool CanSavePfa { get; }
Expand Down Expand Up @@ -536,6 +544,7 @@ internal ImplRaw(TScalarPredictor[] predictors)
CanSavePfa = Predictors.All(m => (m as ISingleCanSavePfa)?.CanSavePfa == true);
Contracts.AssertValue(inputType);
InputType = inputType;
OutputFormula = OutputFormula.Raw;
}

public override ValueMapper<VBuffer<float>, VBuffer<float>> GetMapper()
Expand Down Expand Up @@ -601,6 +610,7 @@ internal ImplDist(IValueMapperDist[] predictors)
CanSavePfa = Predictors.All(m => (m as IDistCanSavePfa)?.CanSavePfa == true);
Contracts.AssertValue(inputType);
InputType = inputType;
OutputFormula = OutputFormula.ProbabilityNormalization;
}

private bool IsValid(IValueMapperDist mapper, ref VectorDataViewType inputType)
Expand Down Expand Up @@ -712,6 +722,7 @@ internal ImplSoftmax(TScalarPredictor[] predictors)
CanSavePfa = false;
Contracts.AssertValue(inputType);
InputType = inputType;
OutputFormula = OutputFormula.Softmax;
}

public override ValueMapper<VBuffer<float>, VBuffer<float>> GetMapper()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.InteropServices;
using Microsoft.ML.TestFrameworkCommon.Attributes;

namespace Microsoft.ML.TestFramework.Attributes
{
/// <summary>
/// A theory for tests requiring LightGBM.
/// </summary>
public sealed class LightGBMTheoryAttribute : EnvironmentSpecificTheoryAttribute
{
public LightGBMTheoryAttribute() : base("LightGBM is 64-bit only")
{
}

/// <inheritdoc />
protected override bool IsEnvironmentSupported()
{
return Environment.Is64BitProcess;
}
}
}
9 changes: 6 additions & 3 deletions test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,15 @@ public void LightGbmMulticlassEstimator()
/// <summary>
/// LightGbmMulticlass TrainerEstimator test with options
/// </summary>
[LightGBMFact]
public void LightGbmMulticlassEstimatorWithOptions()
[LightGBMTheory]
[InlineData(true)]
[InlineData(false)]
public void LightGbmMulticlassEstimatorWithOptions(bool softMax)
{
var options = new LightGbmMulticlassTrainer.Options
{
EvaluationMetric = LightGbmMulticlassTrainer.Options.EvaluateMetricType.Default
EvaluationMetric = LightGbmMulticlassTrainer.Options.EvaluateMetricType.Default,
UseSoftmax = softMax
};

var (pipeline, dataView) = GetMulticlassPipeline();
Expand Down

0 comments on commit f96761b

Please sign in to comment.