Skip to content

Commit 9b27c53

Browse files
authored
Pass group id via options (#2742)
1 parent: be3c35e; commit: 9b27c53

File tree

3 files changed: 10 additions and 7 deletions (+10 -7 lines changed)

src/Microsoft.ML.FastTree/FastTree.cs

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -144,7 +144,8 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
144144
/// Constructor that is used when invoking the classes deriving from this, through maml.
145145
/// </summary>
146146
private protected FastTreeTrainerBase(IHostEnvironment env, TOptions options, SchemaShape.Column label)
147-
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn))
147+
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn),
148+
TrainerUtils.MakeU4ScalarColumn(options.GroupIdColumn))
148149
{
149150
Host.CheckValue(options, nameof(options));
150151
FastTreeTrainerOptions = options;
@@ -1842,7 +1843,7 @@ private void MakeBoundariesAndCheckLabels(out long missingInstances, out long to
18421843
{
18431844
hasGroup = _data.Schema.Group != null;
18441845

1845-
if(hasGroup)
1846+
if (hasGroup)
18461847
curOptions |= CursOpt.Group;
18471848
}
18481849
else

src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ private protected LightGbmTrainerBase(IHostEnvironment env,
8484
}
8585

8686
private protected LightGbmTrainerBase(IHostEnvironment env, string name, Options options, SchemaShape.Column label)
87-
: base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn))
87+
: base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), label,
88+
TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn), TrainerUtils.MakeU4ScalarColumn(options.GroupIdColumn))
8889
{
8990
Host.CheckValue(options, nameof(options));
9091

@@ -163,7 +164,7 @@ private protected virtual void CheckDataValid(IChannel ch, RoleMappedData data)
163164
ch.CheckParam(data.Schema.Label.HasValue, nameof(data), "Need a label column");
164165
}
165166

166-
protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCategarical, int totalCats, bool hiddenMsg=false)
167+
protected virtual void GetDefaultParameters(IChannel ch, int numRow, bool hasCategarical, int totalCats, bool hiddenMsg = false)
167168
{
168169
double learningRate = LightGbmTrainerOptions.LearningRate ?? DefaultLearningRate(numRow, hasCategarical, totalCats);
169170
int numLeaves = LightGbmTrainerOptions.NumLeaves ?? DefaultNumLeaves(numRow, hasCategarical, totalCats);
@@ -584,7 +585,7 @@ private void CreateDatasetFromSamplingData(IChannel ch, FloatLabelCursor.Factory
584585
int[] nonZeroCntPerColumn = new int[catMetaData.NumCol];
585586
int estimateNonZeroCnt = (int)(numSampleRow * density);
586587
estimateNonZeroCnt = Math.Max(1, estimateNonZeroCnt);
587-
for(int i = 0; i < catMetaData.NumCol; i++)
588+
for (int i = 0; i < catMetaData.NumCol; i++)
588589
{
589590
nonZeroCntPerColumn[i] = 0;
590591
sampleValuePerColumn[i] = new double[estimateNonZeroCnt];

test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ public void FastTreeRankerEstimator()
115115
new FastTreeRankingTrainer.Options
116116
{
117117
FeatureColumn = "NumericFeatures",
118-
NumTrees = 10
118+
NumTrees = 10,
119+
GroupIdColumn = "Group"
119120
});
120121

121122
var pipeWithTrainer = pipe.Append(trainer);
@@ -134,7 +135,7 @@ public void LightGBMRankerEstimator()
134135
{
135136
var (pipe, dataView) = GetRankingPipeline();
136137

137-
var trainer = ML.Ranking.Trainers.LightGbm(labelColumnName: "Label0", featureColumnName: "NumericFeatures", rowGroupColumnName: "Group", learningRate: 0.4);
138+
var trainer = ML.Ranking.Trainers.LightGbm(new Options() { LabelColumn = "Label0", FeatureColumn = "NumericFeatures", GroupIdColumn = "Group", LearningRate = 0.4 });
138139

139140
var pipeWithTrainer = pipe.Append(trainer);
140141
TestEstimatorCore(pipeWithTrainer, dataView);

0 commit comments

Comments
 (0)