Skip to content

Commit

Permalink
LightGBM Unbalanced Data Argument [Issue dotnet#3688 Fix] (dotnet#3925)
Browse files Browse the repository at this point in the history
* LightGBM unbalanced data arg added

* unbalanced data argument added

* added tests for unbalanced LightGbm and added the argument for multiclass

* reverted changes on LightGbmArguments

* wording improvement on unbalanced arg help text

* updated manifest

* removed empty line

* added keytype to test
  • Loading branch information
Rayan-Krishnan authored and Dmitry-A committed Jul 24, 2019
1 parent 2f2d0d0 commit 18bb1cf
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 3 deletions.
1 change: 0 additions & 1 deletion src/Microsoft.ML.LightGbm/LightGbmArguments.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ public BoosterParameterBase(OptionsBase options)
public abstract class OptionsBase : IBoosterParameterFactory
{
internal BoosterParameterBase GetBooster() { return null; }

/// <summary>
/// The minimum loss reduction required to make a further partition on a leaf node of the tree.
/// </summary>
Expand Down
7 changes: 7 additions & 0 deletions src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ public enum EvaluateMetricType
LogLoss,
}

/// <summary>
/// Whether training data is unbalanced.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Use for multi-class classification when training data is not balanced", ShortName = "us")]
public bool UnbalancedSets = false;

/// <summary>
/// Whether to use softmax loss.
/// </summary>
Expand Down Expand Up @@ -110,6 +116,7 @@ internal override Dictionary<string, object> ToDictionary(IHost host)
{
var res = base.ToDictionary(host);

res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets;
res[GetOptionName(nameof(Sigmoid))] = Sigmoid;
res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());

Expand Down
12 changes: 12 additions & 0 deletions test/BaselineOutput/Common/EntryPoints/core_manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -11974,6 +11974,18 @@
"IsNullable": false,
"Default": "Auto"
},
{
"Name": "UnbalancedSets",
"Type": "Bool",
"Desc": "Use for multi-class classification when training data is not balanced",
"Aliases": [
"us"
],
"Required": false,
"SortOrder": 150.0,
"IsNullable": false,
"Default": false
},
{
"Name": "UseSoftmax",
"Type": "Bool",
Expand Down
94 changes: 92 additions & 2 deletions test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@ public void LightGBMBinaryEstimator()
NumberOfLeaves = 10,
NumberOfThreads = 1,
MinimumExampleCountPerLeaf = 2,
UnbalancedSets = false, // default value
});

var pipeWithTrainer = pipe.Append(trainer);
TestEstimatorCore(pipeWithTrainer, dataView);

var transformedDataView = pipe.Fit(dataView).Transform(dataView);
var model = trainer.Fit(transformedDataView, transformedDataView);
Done();
}

[LightGBMFact]
public void LightGBMBinaryEstimatorUnbalanced()
{
var (pipe, dataView) = GetBinaryClassificationPipeline();

var trainer = ML.BinaryClassification.Trainers.LightGbm(new LightGbmBinaryTrainer.Options
{
NumberOfLeaves = 10,
NumberOfThreads = 1,
MinimumExampleCountPerLeaf = 2,
UnbalancedSets = true,
});

var pipeWithTrainer = pipe.Append(trainer);
Expand Down Expand Up @@ -322,6 +344,44 @@ public void LightGbmMulticlassEstimatorCorrectSigmoid()
Done();
}

/// <summary>
/// Verifies that the LightGBM multiclass trainer builds and validates a pipeline
/// when <c>UnbalancedSets</c> is explicitly set to its default value (false).
/// </summary>
[LightGBMFact]
public void LightGbmMulticlassEstimatorBalanced()
{
    var (featurePipeline, data) = GetMulticlassPipeline();

    // Explicitly pass the default so the balanced configuration is exercised.
    var options = new LightGbmMulticlassTrainer.Options
    {
        UnbalancedSets = false
    };
    var trainer = ML.MulticlassClassification.Trainers.LightGbm(options);

    var fullPipeline = featurePipeline
        .Append(trainer)
        .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));

    TestEstimatorCore(fullPipeline, data);
    Done();
}

/// <summary>
/// Verifies that the LightGBM multiclass trainer builds and validates a pipeline
/// when <c>UnbalancedSets</c> is enabled (maps to LightGBM's unbalanced-data handling).
/// </summary>
[LightGBMFact]
public void LightGbmMulticlassEstimatorUnbalanced()
{
    var (featurePipeline, data) = GetMulticlassPipeline();

    // Enable the option under test: treat the training data as unbalanced.
    var options = new LightGbmMulticlassTrainer.Options
    {
        UnbalancedSets = true
    };
    var trainer = ML.MulticlassClassification.Trainers.LightGbm(options);

    var fullPipeline = featurePipeline
        .Append(trainer)
        .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));

    TestEstimatorCore(fullPipeline, data);
    Done();
}

// Number of examples
private const int _rowNumber = 1000;
// Number of features
Expand All @@ -338,7 +398,7 @@ private class GbmExample
public float[] Score;
}

private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities)
private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities, bool unbalancedSets = false)
{
// Prepare data and train LightGBM model via ML.NET
// Training matrix. It contains all feature vectors.
Expand Down Expand Up @@ -372,7 +432,8 @@ private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelStr
MinimumExampleCountPerGroup = 1,
MinimumExampleCountPerLeaf = 1,
UseSoftmax = useSoftmax,
Sigmoid = sigmoid // Custom sigmoid value.
Sigmoid = sigmoid, // Custom sigmoid value.
UnbalancedSets = unbalancedSets // false by default
});

var gbm = gbmTrainer.Fit(dataView);
Expand Down Expand Up @@ -583,6 +644,35 @@ public void LightGbmMulticlassEstimatorCompareSoftMax()
Done();
}

/// <summary>
/// Trains the same multiclass model through ML.NET and through native LightGBM with
/// <c>UnbalancedSets = true</c>, then checks per-example, per-class that ML.NET's
/// probabilities match both the native probabilities and a softmax recomputed from
/// the native raw scores.
/// </summary>
[LightGBMFact]
public void LightGbmMulticlassEstimatorCompareUnbalanced()
{
// Train ML.NET LightGBM and native LightGBM and apply the trained models to the training set.
// nativeResult1 holds native raw scores, nativeResult0 holds native probabilities
// (both flattened as [example * _classNumber + class]) — presumably; confirm against LightGbmHelper.
LightGbmHelper(useSoftmax: true, sigmoid: .5, out string modelString, out List<GbmExample> mlnetPredictions, out double[] nativeResult1, out double[] nativeResult0, unbalancedSets:true);

// The i-th predictor returned by LightGBM produces the raw score, denoted by z_i, of the i-th class.
// Assume that we have n classes in total. The i-th class probability can be computed via
// p_i = exp(z_i) / (exp(z_1) + ... + exp(z_n)).
Assert.True(modelString != null);
// Compare native LightGBM's and ML.NET's LightGBM results example by example
for (int i = 0; i < _rowNumber; ++i)
{
// First pass: check ML.NET probabilities against native probabilities (to 6 decimal
// places) while accumulating the softmax denominator from the native raw scores.
double sum = 0;
for (int j = 0; j < _classNumber; ++j)
{
Assert.Equal(nativeResult0[j + i * _classNumber], mlnetPredictions[i].Score[j], 6);
// Cast to float mirrors the single-precision arithmetic used during training/scoring.
sum += Math.Exp((float)nativeResult1[j + i * _classNumber]);
}
// Second pass: recompute each probability from the raw scores and compare again,
// verifying the softmax relationship p_j = exp(z_j) / sum holds for ML.NET's output.
for (int j = 0; j < _classNumber; ++j)
{
double prob = Math.Exp(nativeResult1[j + i * _classNumber]);
Assert.Equal(prob / sum, mlnetPredictions[i].Score[j], 6);
}
}

Done();
}

[LightGBMFact]
public void LightGbmInDifferentCulture()
{
Expand Down

0 comments on commit 18bb1cf

Please sign in to comment.