Skip to content

Commit

Permalink
[c-api] check number of features when retrieving number of bins (micr…
Browse files Browse the repository at this point in the history
…osoft#5183)

* check number of features when retrieving number of bins

* check for negative values

* lint
  • Loading branch information
jmoralez authored Apr 30, 2022
1 parent d893cd1 commit f53fa69
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1555,6 +1555,11 @@ int LGBM_DatasetGetFeatureNumBin(DatasetHandle handle,
int* out) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
int num_features = dataset->num_total_features();
if (feature < 0 || feature >= num_features) {
Log::Fatal("Tried to retrieve number of bins for feature index %d, "
"but the valid feature indices are [0, %d].", feature, num_features - 1);
}
int inner_idx = dataset->InnerFeatureIndex(feature);
if (inner_idx >= 0) {
*out = dataset->FeatureNumBin(inner_idx);
Expand Down
10 changes: 10 additions & 0 deletions tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,16 @@ def test_feature_num_bin(min_data_in_bin):
]
actual_num_bins = [ds.feature_num_bin(i) for i in range(X.shape[1])]
assert actual_num_bins == expected_num_bins
# check for feature indices outside of range
num_features = X.shape[1]
with pytest.raises(
lgb.basic.LightGBMError,
match=(
f'Tried to retrieve number of bins for feature index {num_features}, '
f'but the valid feature indices are \\[0, {num_features - 1}\\].'
)
):
ds.feature_num_bin(num_features)


def test_feature_num_bin_with_max_bin_by_feature():
Expand Down

0 comments on commit f53fa69

Please sign in to comment.