@@ -84,7 +84,8 @@ private protected LightGbmTrainerBase(IHostEnvironment env,
84
84
}
85
85
86
86
private protected LightGbmTrainerBase ( IHostEnvironment env , string name , Options options , SchemaShape . Column label )
87
- : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( name ) , TrainerUtils . MakeR4VecFeature ( options . FeatureColumn ) , label , TrainerUtils . MakeR4ScalarWeightColumn ( options . WeightColumn ) )
87
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( name ) , TrainerUtils . MakeR4VecFeature ( options . FeatureColumn ) , label ,
88
+ TrainerUtils . MakeR4ScalarWeightColumn ( options . WeightColumn ) , TrainerUtils . MakeU4ScalarColumn ( options . GroupIdColumn ) )
88
89
{
89
90
Host . CheckValue ( options , nameof ( options ) ) ;
90
91
@@ -163,7 +164,7 @@ private protected virtual void CheckDataValid(IChannel ch, RoleMappedData data)
163
164
ch . CheckParam ( data . Schema . Label . HasValue , nameof ( data ) , "Need a label column" ) ;
164
165
}
165
166
166
- protected virtual void GetDefaultParameters ( IChannel ch , int numRow , bool hasCategarical , int totalCats , bool hiddenMsg = false )
167
+ protected virtual void GetDefaultParameters ( IChannel ch , int numRow , bool hasCategarical , int totalCats , bool hiddenMsg = false )
167
168
{
168
169
double learningRate = LightGbmTrainerOptions . LearningRate ?? DefaultLearningRate ( numRow , hasCategarical , totalCats ) ;
169
170
int numLeaves = LightGbmTrainerOptions . NumLeaves ?? DefaultNumLeaves ( numRow , hasCategarical , totalCats ) ;
@@ -584,7 +585,7 @@ private void CreateDatasetFromSamplingData(IChannel ch, FloatLabelCursor.Factory
584
585
int [ ] nonZeroCntPerColumn = new int [ catMetaData . NumCol ] ;
585
586
int estimateNonZeroCnt = ( int ) ( numSampleRow * density ) ;
586
587
estimateNonZeroCnt = Math . Max ( 1 , estimateNonZeroCnt ) ;
587
- for ( int i = 0 ; i < catMetaData . NumCol ; i ++ )
588
+ for ( int i = 0 ; i < catMetaData . NumCol ; i ++ )
588
589
{
589
590
nonZeroCntPerColumn [ i ] = 0 ;
590
591
sampleValuePerColumn [ i ] = new double [ estimateNonZeroCnt ] ;
0 commit comments