@@ -805,6 +805,7 @@ Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
805
805
cp_ptr += bin_mappers[i]->SizesInByte ();
806
806
}
807
807
}
808
+ CheckCategoricalFeatureNumBin (bin_mappers, config_.max_bin , config_.max_bin_by_feature );
808
809
auto dataset = std::unique_ptr<Dataset>(new Dataset (num_data));
809
810
dataset->Construct (&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
810
811
if (dataset->has_raw ()) {
@@ -1184,6 +1185,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
1184
1185
cp_ptr += bin_mappers[i]->SizesInByte ();
1185
1186
}
1186
1187
}
1188
+ CheckCategoricalFeatureNumBin (bin_mappers, config_.max_bin , config_.max_bin_by_feature );
1187
1189
dataset->Construct (&bin_mappers, dataset->num_total_features_ , forced_bin_bounds, Common::Vector2Ptr<int >(&sample_indices).data (),
1188
1190
Common::Vector2Ptr<double >(&sample_values).data (),
1189
1191
Common::VectorSize<int >(sample_indices).data (), static_cast <int >(sample_indices.size ()), sample_data.size (), config_);
@@ -1463,4 +1465,44 @@ std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced
1463
1465
return forced_bins;
1464
1466
}
1465
1467
1468
+ void DatasetLoader::CheckCategoricalFeatureNumBin (
1469
+ const std::vector<std::unique_ptr<BinMapper>>& bin_mappers,
1470
+ const int max_bin, const std::vector<int >& max_bin_by_feature) const {
1471
+ bool need_warning = false ;
1472
+ if (bin_mappers.size () < 1024 ) {
1473
+ for (size_t i = 0 ; i < bin_mappers.size (); ++i) {
1474
+ const int max_bin_for_this_feature = max_bin_by_feature.empty () ? max_bin : max_bin_by_feature[i];
1475
+ if (bin_mappers[i]->bin_type () == BinType::CategoricalBin && bin_mappers[i]->num_bin () > max_bin_for_this_feature) {
1476
+ need_warning = true ;
1477
+ break ;
1478
+ }
1479
+ }
1480
+ } else {
1481
+ const int num_threads = OMP_NUM_THREADS ();
1482
+ std::vector<bool > thread_need_warning (num_threads, false );
1483
+ Threading::For<size_t >(0 , bin_mappers.size (), 1 ,
1484
+ [&bin_mappers, &thread_need_warning, &max_bin_by_feature, max_bin] (int thread_index, size_t start, size_t end) {
1485
+ for (size_t i = start; i < end; ++i) {
1486
+ thread_need_warning[thread_index] = false ;
1487
+ const int max_bin_for_this_feature = max_bin_by_feature.empty () ? max_bin : max_bin_by_feature[i];
1488
+ if (bin_mappers[i]->bin_type () == BinType::CategoricalBin && bin_mappers[i]->num_bin () > max_bin_for_this_feature) {
1489
+ thread_need_warning[thread_index] = true ;
1490
+ break ;
1491
+ }
1492
+ }
1493
+ });
1494
+ for (int thread_index = 0 ; thread_index < num_threads; ++thread_index) {
1495
+ if (thread_need_warning[thread_index]) {
1496
+ need_warning = true ;
1497
+ break ;
1498
+ }
1499
+ }
1500
+ }
1501
+
1502
+ if (need_warning) {
1503
+ Log::Warning (" Categorical features with more bins than the configured maximum bin number found." );
1504
+ Log::Warning (" For categorical features, max_bin and max_bin_by_feature may be ignored with a large number of categories." );
1505
+ }
1506
+ }
1507
+
1466
1508
} // namespace LightGBM
0 commit comments