LightGBM Unbalanced Data Argument [Issue #3688 Fix] #3925
/// Whether training data is unbalanced. Used by <see cref="LightGbmBinaryTrainer"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Use for binary classification when training data is not balanced.", ShortName = "us")]
public bool UnbalancedSets = false;
We need two tests: one for UnbalancedSets=true and one for UnbalancedSets=false. In addition to simply running a model, we need a more restrictive comparison, like
[LightGBMFact]
public void LightGbmMulticlassEstimatorCompareOva()
{
    float sigmoidScale = 0.5f; // Constant used to train LightGBM. See gbmParams["sigmoid"] in the helper function.
    // Train ML.NET LightGBM and native LightGBM and apply the trained models to the training set.
    LightGbmHelper(useSoftmax: false, sigmoid: sigmoidScale, out string modelString, out List<GbmExample> mlnetPredictions, out double[] nativeResult1, out double[] nativeResult0);
    // The i-th predictor returned by LightGBM produces the raw score, denoted by z_i, of the i-th class.
    // Assume that we have n classes in total. The i-th class probability can be computed via
    // p_i = sigmoid(sigmoidScale * z_i) / (sigmoid(sigmoidScale * z_1) + ... + sigmoid(sigmoidScale * z_n)).
    Assert.True(modelString != null);
    // Compare native LightGBM's and ML.NET's LightGBM results example by example.
    for (int i = 0; i < _rowNumber; ++i)
    {
        double sum = 0;
        for (int j = 0; j < _classNumber; ++j)
        {
            Assert.Equal(nativeResult0[j + i * _classNumber], mlnetPredictions[i].Score[j], 6);
            if (float.IsNaN((float)nativeResult1[j + i * _classNumber]))
                continue;
            sum += MathUtils.SigmoidSlow(sigmoidScale * (float)nativeResult1[j + i * _classNumber]);
        }
        for (int j = 0; j < _classNumber; ++j)
        {
            double prob = MathUtils.SigmoidSlow(sigmoidScale * (float)nativeResult1[j + i * _classNumber]);
            Assert.Equal(prob / sum, mlnetPredictions[i].Score[j], 6);
        }
    }
    Done();
}
in machinelearning2\test\Microsoft.ML.Tests\TrainerEstimators\TreeEstimators.cs. Note that this test directly checks whether ML.NET's LightGBM module and LightGBM's C API produce identical numbers.
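To make the normalization described in the test's comments concrete, here is a minimal standalone sketch of p_i = sigmoid(sigmoidScale * z_i) / (sigmoid(sigmoidScale * z_1) + ... + sigmoid(sigmoidScale * z_n)). The class name `OvaNormalizationSketch`, its `Sigmoid` and `Normalize` helpers, and the raw scores are hypothetical and chosen for illustration only; the plain logistic function stands in for `MathUtils.SigmoidSlow`, whose implementation is not shown in this thread. Unlike the test, the sketch does not skip NaN scores.

```csharp
using System;
using System.Linq;

public static class OvaNormalizationSketch
{
    // Plain logistic function. Assumption: stands in for MathUtils.SigmoidSlow.
    public static double Sigmoid(double x) => 1.0 / (1.0 + Math.Exp(-x));

    // Turns per-class raw scores z_i into probabilities:
    // p_i = sigmoid(scale * z_i) / sum_k sigmoid(scale * z_k).
    public static double[] Normalize(double[] rawScores, double sigmoidScale)
    {
        double[] squashed = rawScores.Select(z => Sigmoid(sigmoidScale * z)).ToArray();
        double sum = squashed.Sum();
        return squashed.Select(s => s / sum).ToArray();
    }

    public static void Main()
    {
        // Made-up raw scores for a three-class problem (not taken from the test data).
        double[] rawScores = { 2.0, 0.0, -2.0 };
        double[] probabilities = Normalize(rawScores, sigmoidScale: 0.5);

        // The probabilities sum to 1 and keep the ordering of the raw scores.
        Console.WriteLine(string.Join(" ", probabilities));
        Console.WriteLine(probabilities.Sum());
    }
}
```

The test applies the same computation to the native LightGBM raw scores and compares each result with the ML.NET score to six decimal places.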
Thank you. I will add and run these tests.
I needed to add the unbalanced argument to LightGbmMulticlassTrainer, and ran a test to revise the manifest file. Tests added:
…to LightGBM-unbalanced-data
/// <summary>
/// Whether training data is unbalanced.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Use for binary classification when training data is not balanced.", ShortName = "us")]
"Use for binary classification when training data is not balanced."
It's odd to say "use for binary classification" when this is the multiclass trainer.
Would it be more helpful to just say "Use for multi-class classification when training data is not balanced"?
/// <summary>
/// Whether training data is unbalanced. Used by <see cref="LightGbmBinaryTrainer"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Use for binary classification when training data is not balanced.", ShortName = "us")]
"Use for binary classification when training data is not balanced."
Is putting this option on the OptionsBase class correct if it is only for binary classification?
You're right. The cause of the problem was this commit, which split the trainers into their own classes (binary and multiclass). Is it better not to include it as an argument in this file?
/// Whether training data is unbalanced. Used by <see cref="LightGbmBinaryTrainer"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Use for binary classification when training data is not balanced.", ShortName = "us")]
public bool UnbalancedSets = false;
Where is this boolean being used? I don't see any code added using it.
It's added to the NameMapping dictionary on line 39:
{nameof(OptionsBase.UnbalancedSets), "is_unbalance"},
You are declaring 2 new boolean fields. I only see 1 being used.
@@ -110,6 +116,7 @@ static Options()
{
    var res = base.ToDictionary(host);

    res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets;
res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets;
How is this getting mapped to is_unbalance (default = false, type = bool, aliases: unbalance, unbalanced_sets)?
Never mind, I see it now:
public static string GetOptionName(string name)
{
    // Otherwise convert the name to the light gbm argument
    StringBuilder strBuf = new StringBuilder();
    bool first = true;
    foreach (char c in name)
    {
        if (char.IsUpper(c))
        {
            if (first)
                first = false;
            else
                strBuf.Append('_');
            strBuf.Append(char.ToLower(c));
        }
        else
            strBuf.Append(c);
    }
    return strBuf.ToString();
}
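As a rough trace of how the option name resolves, the sketch below reproduces the conversion shown above in a standalone program. The class `OptionNameSketch` and the `nameMapping` parameter are hypothetical: the lookup-before-conversion step is an assumption inferred from the "Otherwise" comment and from the NameMapping entry quoted earlier in this thread, not a copy of the actual ML.NET source.

```csharp
using System;
using System.Collections.Generic;
using System.Text;

public static class OptionNameSketch
{
    // Assumption: an explicit mapping, if present, wins; otherwise the
    // PascalCase name is converted to snake_case as in the quoted snippet.
    public static string GetOptionName(string name, IReadOnlyDictionary<string, string> nameMapping)
    {
        if (nameMapping.TryGetValue(name, out string mapped))
            return mapped;

        StringBuilder strBuf = new StringBuilder();
        bool first = true;
        foreach (char c in name)
        {
            if (char.IsUpper(c))
            {
                // Every upper-case letter except the first starts a new word.
                if (first)
                    first = false;
                else
                    strBuf.Append('_');
                strBuf.Append(char.ToLower(c));
            }
            else
                strBuf.Append(c);
        }
        return strBuf.ToString();
    }

    public static void Main()
    {
        var noMapping = new Dictionary<string, string>();
        var withMapping = new Dictionary<string, string> { { "UnbalancedSets", "is_unbalance" } };

        Console.WriteLine(GetOptionName("UnbalancedSets", noMapping));   // unbalanced_sets
        Console.WriteLine(GetOptionName("UnbalancedSets", withMapping)); // is_unbalance
    }
}
```

Either result matches a name from the parameter description quoted above, which lists unbalanced_sets as an alias of is_unbalance.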
Sigmoid = sigmoid // Custom sigmoid value.
Sigmoid = sigmoid, // Custom sigmoid value.
UnbalancedSets = unbalancedSets // false by default

Empty line.
…ightGBM-unbalanced-data
…tnet#3925)"
This reverts commit c5a18ef.
# Conflicts:
# test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
* LightGBM unbalanced data arg added
* unbalanced data argument added
* tests for unbalanced LightGbm and added arg for multiclass
* reverted changes on LightGbmArguments
* wording improvement on unbalanced arg help text
* updated manifest
* removed empty line
* added keytype to test
Fix for issue #3688. The LightGBM Multiclass Trainer can now accept the unbalanced data parameter, as was previously possible in the Binary Trainer. An additional argument was added to the LightGBMBinaryEstimator test.