Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LightGBM Unbalanced Data Argument [Issue #3688 Fix] #3925

Merged
merged 12 commits into from
Jul 1, 2019
1 change: 0 additions & 1 deletion src/Microsoft.ML.LightGbm/LightGbmArguments.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ public BoosterParameterBase(OptionsBase options)
public abstract class OptionsBase : IBoosterParameterFactory
{
internal BoosterParameterBase GetBooster() { return null; }

/// <summary>
/// The minimum loss reduction required to make a further partition on a leaf node of the tree.
/// </summary>
Expand Down
7 changes: 7 additions & 0 deletions src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ public enum EvaluateMetricType
LogLoss,
}

/// <summary>
/// Whether training data is unbalanced.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Use for multi-class classification when training data is not balanced", ShortName = "us")]
public bool UnbalancedSets = false;

/// <summary>
/// Whether to use softmax loss.
/// </summary>
Expand Down Expand Up @@ -110,6 +116,7 @@ internal override Dictionary<string, object> ToDictionary(IHost host)
{
var res = base.ToDictionary(host);

res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets;
Copy link
Member

@codemzs codemzs Jul 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets; [](start = 16, length = 60)

how is this getting mapped to is_unbalance 🔗︎, default = false, type = bool, aliases: unbalance, unbalanced_sets #Resolved

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nevermind I see it now:

/// <summary>
/// Converts a PascalCase ML.NET option name (e.g. "UnbalancedSets") into the
/// lower-snake-case form used by LightGBM (e.g. "unbalanced_sets").
/// An underscore is inserted before every uppercase letter except the first
/// uppercase letter encountered; all other characters pass through unchanged.
/// </summary>
public static string GetOptionName(string name)
{
    // Otherwise convert the name to the light gbm argument
    var lightGbmName = new StringBuilder();
    var seenUpper = false;
    foreach (var ch in name)
    {
        if (!char.IsUpper(ch))
        {
            lightGbmName.Append(ch);
            continue;
        }

        // Word boundary: only separate with '_' once a previous uppercase
        // letter has been seen (the first one gets no leading underscore).
        if (seenUpper)
            lightGbmName.Append('_');
        seenUpper = true;
        lightGbmName.Append(char.ToLower(ch));
    }
    return lightGbmName.ToString();
}


In reply to: 299023864 [](ancestors = 299023864)

res[GetOptionName(nameof(Sigmoid))] = Sigmoid;
res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());

Expand Down
12 changes: 12 additions & 0 deletions test/BaselineOutput/Common/EntryPoints/core_manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -11974,6 +11974,18 @@
"IsNullable": false,
"Default": "Auto"
},
{
"Name": "UnbalancedSets",
"Type": "Bool",
"Desc": "Use for multi-class classification when training data is not balanced",
"Aliases": [
"us"
],
"Required": false,
"SortOrder": 150.0,
"IsNullable": false,
"Default": false
},
{
"Name": "UseSoftmax",
"Type": "Bool",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ public void TreeEnsembleFeaturizingPipelineMulticlass()

private class RowWithKey
{
[KeyType()]
[KeyType(4)]
public uint KeyLabel { get; set; }
}

Expand Down
94 changes: 92 additions & 2 deletions test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@ public void LightGBMBinaryEstimator()
NumberOfLeaves = 10,
NumberOfThreads = 1,
MinimumExampleCountPerLeaf = 2,
UnbalancedSets = false, // default value
});

var pipeWithTrainer = pipe.Append(trainer);
TestEstimatorCore(pipeWithTrainer, dataView);

var transformedDataView = pipe.Fit(dataView).Transform(dataView);
var model = trainer.Fit(transformedDataView, transformedDataView);
Done();
}

[LightGBMFact]
public void LightGBMBinaryEstimatorUnbalanced()
{
var (pipe, dataView) = GetBinaryClassificationPipeline();

var trainer = ML.BinaryClassification.Trainers.LightGbm(new LightGbmBinaryTrainer.Options
{
NumberOfLeaves = 10,
NumberOfThreads = 1,
MinimumExampleCountPerLeaf = 2,
UnbalancedSets = true,
});

var pipeWithTrainer = pipe.Append(trainer);
Expand Down Expand Up @@ -322,6 +344,44 @@ public void LightGbmMulticlassEstimatorCorrectSigmoid()
Done();
}

/// <summary>
/// LightGbmMulticlass Test of Balanced Data
/// </summary>
[LightGBMFact]
public void LightGbmMulticlassEstimatorBalanced()
{
    var (basePipeline, data) = GetMulticlassPipeline();

    // Explicitly pass the default value to exercise the option's plumbing.
    var options = new LightGbmMulticlassTrainer.Options
    {
        UnbalancedSets = false
    };

    var fullPipeline = basePipeline
        .Append(ML.MulticlassClassification.Trainers.LightGbm(options))
        .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));

    TestEstimatorCore(fullPipeline, data);
    Done();
}

/// <summary>
/// LightGbmMulticlass Test of Unbalanced Data
/// </summary>
[LightGBMFact]
public void LightGbmMulticlassEstimatorUnbalanced()
{
    var (basePipeline, data) = GetMulticlassPipeline();

    // Turn on LightGBM's unbalanced-data handling (maps to is_unbalance).
    var options = new LightGbmMulticlassTrainer.Options
    {
        UnbalancedSets = true
    };

    var fullPipeline = basePipeline
        .Append(ML.MulticlassClassification.Trainers.LightGbm(options))
        .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));

    TestEstimatorCore(fullPipeline, data);
    Done();
}

// Number of examples
private const int _rowNumber = 1000;
// Number of features
Expand All @@ -338,7 +398,7 @@ private class GbmExample
public float[] Score;
}

private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities)
private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelString, out List<GbmExample> mlnetPredictions, out double[] lgbmRawScores, out double[] lgbmProbabilities, bool unbalancedSets = false)
{
// Prepare data and train LightGBM model via ML.NET
// Training matrix. It contains all feature vectors.
Expand Down Expand Up @@ -372,7 +432,8 @@ private void LightGbmHelper(bool useSoftmax, double sigmoid, out string modelStr
MinimumExampleCountPerGroup = 1,
MinimumExampleCountPerLeaf = 1,
UseSoftmax = useSoftmax,
Sigmoid = sigmoid // Custom sigmoid value.
Sigmoid = sigmoid, // Custom sigmoid value.
UnbalancedSets = unbalancedSets // false by default
});

var gbm = gbmTrainer.Fit(dataView);
Expand Down Expand Up @@ -583,6 +644,35 @@ public void LightGbmMulticlassEstimatorCompareSoftMax()
Done();
}

[LightGBMFact]
public void LightGbmMulticlassEstimatorCompareUnbalanced()
{
// Train ML.NET LightGBM and native LightGBM and apply the trained models to the training set.
// Per the LightGbmHelper signature, nativeResult1 receives the native raw scores and
// nativeResult0 the native probabilities; unbalancedSets:true exercises the new option.
LightGbmHelper(useSoftmax: true, sigmoid: .5, out string modelString, out List<GbmExample> mlnetPredictions, out double[] nativeResult1, out double[] nativeResult0, unbalancedSets:true);

// The i-th predictor returned by LightGBM produces the raw score, denoted by z_i, of the i-th class.
// Assume that we have n classes in total. The i-th class probability can be computed via
// p_i = exp(z_i) / (exp(z_1) + ... + exp(z_n)).
Assert.True(modelString != null);
// Compare native LightGBM's and ML.NET's LightGBM results example by example
for (int i = 0; i < _rowNumber; ++i)
{
double sum = 0;
for (int j = 0; j < _classNumber; ++j)
{
// Each ML.NET probability must match the native probability (to 6 decimal digits).
Assert.Equal(nativeResult0[j + i * _classNumber], mlnetPredictions[i].Score[j], 6);
// Accumulate the softmax denominator from the raw scores.
// NOTE(review): the (float) cast narrows the raw score before exp — presumably to
// mirror ML.NET's single-precision Score values; confirm this is intentional.
sum += Math.Exp((float)nativeResult1[j + i * _classNumber]);
}
for (int j = 0; j < _classNumber; ++j)
{
// Recompute each probability from the raw scores via softmax and verify it also
// matches ML.NET's output, i.e. raw scores and probabilities are mutually consistent.
double prob = Math.Exp(nativeResult1[j + i * _classNumber]);
Assert.Equal(prob / sum, mlnetPredictions[i].Score[j], 6);
}
}

Done();
}

[LightGBMFact]
public void LightGbmInDifferentCulture()
{
Expand Down