Skip to content

Commit

Permalink
Add check to prevent out of index lookup in the position discount table.
Browse files Browse the repository at this point in the history
Add debug logging to report number of queries found in the data.
  • Loading branch information
ashok-ponnuswami-msft committed Mar 13, 2021
1 parent 37e9878 commit 73e7ab4
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 0 deletions.
8 changes: 8 additions & 0 deletions include/LightGBM/metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ class DCGCalculator {
static double CalMaxDCGAtK(data_size_t k,
const label_t* label, data_size_t num_data);


/*!
* \brief Check the query sizes in the metadata for NDCG and lambdarank.
*        Reads the query boundaries from the metadata and, for each of the
*        first num_queries queries, computes its number of rows as the
*        difference between consecutive boundaries. Reports a fatal error
*        through Log::Fatal if any query has more than kMaxPosition rows.
*        Does nothing when num_queries is not positive or when the metadata
*        has no query boundaries.
* \param metadata Metadata providing the query boundaries
* \param num_queries Number of queries to check
*/
static void CheckMetadata(const Metadata& metadata, data_size_t num_queries);

/*!
* \brief Check the label range for NDCG and lambdarank
* \param label Pointer of label
Expand Down
2 changes: 2 additions & 0 deletions src/io/metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
for (size_t i = 0; i < tmp_buffer.size(); ++i) {
query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
}
Log::Info("Number of queries: %i. Average number of rows per query: %f.",
static_cast<int>(num_queries_), static_cast<double>(num_data_) / num_queries_);
LoadQueryWeights();
queries_.clear();
}
Expand Down
13 changes: 13 additions & 0 deletions src/metric/dcg_calculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const label_t* la
}
}

void DCGCalculator::CheckMetadata(const Metadata& metadata, data_size_t num_queries) {
const data_size_t* query_boundaries = metadata.query_boundaries();
if (num_queries > 0 && query_boundaries != nullptr) {
for (data_size_t i = 0; i < num_queries; i++) {
data_size_t num_rows = query_boundaries[i + 1] - query_boundaries[i];
if (num_rows > kMaxPosition) {
Log::Fatal("Number of rows %i exceeds upper limit of %i for a query", static_cast<int>(num_rows), static_cast<int>(kMaxPosition));
}
}
}
}


void DCGCalculator::CheckLabel(const label_t* label, data_size_t num_data) {
for (data_size_t i = 0; i < num_data; ++i) {
label_t delta = std::fabs(label[i] - static_cast<int>(label[i]));
Expand Down
1 change: 1 addition & 0 deletions src/objective/rank_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class LambdarankNDCG : public RankingObjective {

void Init(const Metadata& metadata, data_size_t num_data) override {
RankingObjective::Init(metadata, num_data);
DCGCalculator::CheckMetadata(metadata, num_queries_);
DCGCalculator::CheckLabel(label_, num_data_);
inverse_max_dcgs_.resize(num_queries_);
#pragma omp parallel for schedule(static)
Expand Down

0 comments on commit 73e7ab4

Please sign in to comment.