Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RF integer overflow #4563

Merged
merged 3 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ DI void partitionSamples(const Dataset<DataT, LabelT, IdxT>& dataset,
auto* rcomp = reinterpret_cast<IdxT*>(smem + smemSize);
auto range_start = work_item.instances.begin;
auto range_len = work_item.instances.count;
auto* col = dataset.data + split.colid * dataset.M;
auto* col = dataset.data + split.colid * std::size_t(dataset.M);
auto loffset = range_start, part = loffset + split.nLeft, roffset = part;
auto end = range_start + range_len;
int lflag = 0, rflag = 0, llen = 0, rlen = 0, minlen = 0;
Expand Down Expand Up @@ -286,7 +286,9 @@ __global__ void computeSplitKernel(BinT* histograms,
__syncthreads();

// compute pdf shared histogram for all bins for all classes in shared mem
auto col_offset = col * dataset.M;

// Must be 64-bit: the column offset (col * dataset.M) can easily exceed the range of a 32-bit int
std::size_t col_offset = std::size_t(col) * dataset.M;
for (auto i = range_start + tid; i < end; i += stride) {
// each thread works over a data point and strides to the next
auto row = dataset.row_ids[i];
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/randomforest/randomforest.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ class RandomForest {
std::vector<L> h_predictions(n_rows);
cudaStream_t stream = user_handle.get_stream();

std::vector<T> h_input(n_rows * n_cols);
raft::update_host(h_input.data(), input, n_rows * n_cols, stream);
std::vector<T> h_input(std::size_t(n_rows) * n_cols);
raft::update_host(h_input.data(), input, std::size_t(n_rows) * n_cols, stream);
user_handle.sync_stream(stream);

int row_size = n_cols;
Expand Down
50 changes: 45 additions & 5 deletions cpp/test/sg/rf_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ std::vector<ParamT> SampleParameters(int num_samples, size_t seed, Args... args)
}

struct RfTestParams {
int n_rows;
int n_cols;
std::size_t n_rows;
std::size_t n_cols;
int n_trees;
float max_features;
float max_samples;
Expand All @@ -120,8 +120,8 @@ struct RfTestParams {
bool double_precision;
// C++ has no reflection, so we enumerate the types here
// This must be updated if new fields are added
using types = std::tuple<int,
int,
using types = std::tuple<std::size_t,
std::size_t,
int,
float,
float,
Expand Down Expand Up @@ -333,7 +333,7 @@ class RfSpecialisedTest {

EXPECT_LE(forest->trees[i]->depth_counter, params.max_depth);
EXPECT_LE(forest->trees[i]->leaf_counter,
raft::ceildiv(params.n_rows, params.min_samples_leaf));
raft::ceildiv(int(params.n_rows), params.min_samples_leaf));
}
}

Expand Down Expand Up @@ -523,6 +523,46 @@ INSTANTIATE_TEST_CASE_P(RfTests,
n_labels,
double_precision)));

// Regression test for 32-bit index overflow: trains a random forest on a
// dataset whose element count (m * n) is at least 2^31, then runs FIL
// prediction over the same data. Apart from the leaf-count check below, the
// test only verifies that the whole pipeline runs; predictions are not
// compared against any expected values.
TEST(RfTests, IntegerOverflow)
{
// m and n are passed to fit() in that order after the data pointer;
// presumably m = rows and n = columns -- confirm against fit()'s signature.
std::size_t m = 1000000;
std::size_t n = 2150;
// Guard the premise of the test: 1,000,000 * 2,150 = 2.15e9 >= 2^31, so the
// element count does not fit in a signed 32-bit int.
EXPECT_GE(m * n, 1ull << 31);
// NOTE(review): X holds m * n floats in a thrust::device_vector (roughly
// 8.6 GB assuming 4-byte float), so this test presumably needs a device with
// enough free memory -- confirm CI hardware.
thrust::device_vector<float> X(m * n);
thrust::device_vector<float> y(m);
// Fill features and targets with random values. The constructor argument 4 is
// presumably the seed, and (0.0f, 2.0f) presumably mean and standard
// deviation -- confirm against raft::random::Rng.
raft::random::Rng r(4);
r.normal(X.data().get(), X.size(), 0.0f, 2.0f, nullptr);
r.normal(y.data().get(), y.size(), 0.0f, 2.0f, nullptr);
auto forest = std::make_shared<RandomForestMetaData<float, float>>();
auto forest_ptr = forest.get();
auto stream_pool = std::make_shared<rmm::cuda_stream_pool>(4);
raft::handle_t handle(rmm::cuda_stream_per_thread, stream_pool);
// Positional hyper-parameters; their meanings are defined by set_rf_params,
// which is not visible here. Only CRITERION::MSE is self-describing.
RF_params rf_params =
set_rf_params(3, 100, 1.0, 256, 1, 2, 0.0, false, 1, 1.0, 0, CRITERION::MSE, 4, 128);
fit(handle, forest_ptr, X.data().get(), m, n, y.data().get(), rf_params);

// Check we have actually learned something
// (the first tree must have more than one leaf).
EXPECT_GT(forest->trees[0]->leaf_counter, 1);

// See if fil overflows
// The trained forest is converted via build_treelite_forest and
// fil::from_treelite, then fil::predict is run over the same X.
thrust::device_vector<float> pred(m);
ModelHandle model;
build_treelite_forest(&model, forest_ptr, n);

// With num_outputs == 1, `num_outputs > 1` is false and `1.f / num_outputs`
// is 1.0f. The remaining initializers are positional; their field meanings
// come from fil::treelite_params_t, which is not visible here.
std::size_t num_outputs = 1;
fil::treelite_params_t tl_params{fil::algo_t::ALGO_AUTO,
num_outputs > 1,
1.f / num_outputs,
fil::storage_type_t::AUTO,
8,
1,
0,
nullptr};
fil::forest_t fil_forest;
fil::from_treelite(handle, &fil_forest, model, &tl_params);
fil::predict(handle, fil_forest, pred.data().get(), X.data().get(), m, false);
// NOTE(review): neither `model` nor `fil_forest` is explicitly released in
// this block -- confirm whether cleanup is required or handled elsewhere.
}

//-------------------------------------------------------------------------------------------------------------------------------------
struct QuantileTestParameters {
int n_rows;
Expand Down