Enable warp-per-tree inference in FIL for regression and binary classification #3760

Merged (40 commits) on Jun 14, 2021
Commits
8fedfd7
try 1
levsnv Apr 17, 2021
80b919f
fixed python test
levsnv Apr 18, 2021
f54ffa6
copyright year
levsnv Apr 18, 2021
18651ce
enhanced python tests; threaded n_items through python layer; one deb…
levsnv Apr 18, 2021
3f4a5d6
refactor
levsnv Apr 28, 2021
9e5c5a4
style
levsnv Apr 28, 2021
8ae9e5f
style
levsnv Apr 28, 2021
5149604
addressed some review comments
levsnv Apr 28, 2021
cc7e365
Merge remote-tracking branch 'levs/refactor-cython-kwargs' into warp-…
levsnv Apr 28, 2021
aa40a64
fixed all bugs
levsnv May 4, 2021
39e05af
Merge remote-tracking branch 'rapidsai/branch-0.20' into warp-per-tre…
levsnv May 4, 2021
e4742d2
fixed some bugs; documentation; using int_fastdiv for inner loop in p…
levsnv May 14, 2021
af7406b
doc text
levsnv May 14, 2021
b01a543
almost all gains can be made without using Maxim's particular fastdiv…
levsnv May 19, 2021
8b63f7c
style
levsnv May 19, 2021
5627d30
successfully sped up warp/tree for tpt < 32 and even tpt == 32
levsnv May 20, 2021
a701332
remove unnecessary __syncwarp() and extra LOC
levsnv May 20, 2021
8d70668
use CUB for 1 thread/tree
levsnv May 22, 2021
6546955
remove syncthreads; use threads/tree to determine shmem footprint
levsnv May 22, 2021
e989cd8
Merge remote-tracking branch 'rapidsai/branch-21.06' into warp-per-tr…
levsnv May 24, 2021
2502b65
convert last value to fixed width
levsnv May 26, 2021
f0762ff
Merge remote-tracking branch 'rapidsai/branch-21.08' into warp-per-tr…
levsnv May 26, 2021
4b78439
Merge remote-tracking branch 'rapidsai/branch-21.08' into warp-per-tr…
levsnv Jun 2, 2021
90b1292
FIX fix kernel and line info in cmake
dantegd Jun 3, 2021
36c80a2
FIX Use ucx-py 0.21
dantegd Jun 3, 2021
8594eda
DBG Try installing xgb 21.06 after the big conda install
dantegd Jun 5, 2021
0402975
DBG Try using mamba to avoid timeouts
dantegd Jun 6, 2021
5a5d184
DBG install xgboost after libcuml artifact
dantegd Jun 6, 2021
e1bca08
DBG Playing with using mamba a little bit more
dantegd Jun 6, 2021
24b3198
DBG Remove xgboost instead of all of the above
dantegd Jun 7, 2021
9692b43
DBG correct ucx-py version
dantegd Jun 7, 2021
0f28f31
FIX Change order of commented code to make the script happy
dantegd Jun 7, 2021
3440b8e
Merge branch 'branch-21.08' into 2108-fix-likerinfo
dantegd Jun 7, 2021
8d20def
FIX Merge main and use dask main
dantegd Jun 7, 2021
3f4d875
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into warp-per…
levsnv Jun 8, 2021
71e0685
FIX add back xgboost now that package is published
dantegd Jun 8, 2021
2bf79d9
Merge branch '2108-fix-likerinfo' of github.com:dantegd/cuml into war…
levsnv Jun 9, 2021
e71ef74
added void accumulate(..., int num_rows) back since vector leaf will …
levsnv Jun 9, 2021
7b1f2d3
readability
levsnv Jun 9, 2021
cec71e5
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into warp-per…
levsnv Jun 10, 2021
4 changes: 3 additions & 1 deletion cpp/bench/sg/fil.cu
@@ -91,7 +91,9 @@ class FIL : public RegressionFixture<float> {
.output_class = params.nclasses > 1, // cuML RF forest
.threshold = 1.f / params.nclasses, //Fixture::DatasetParams
.storage_type = p_rest.storage,
.blocks_per_sm = 0,
.blocks_per_sm = 8,
.threads_per_tree = 1,
.n_items = 0,
.pforest_shape_str = nullptr};
ML::fil::from_treelite(*handle, &forest, model, &tl_params);

6 changes: 6 additions & 0 deletions cpp/include/cuml/fil/fil.h
@@ -90,6 +90,12 @@ struct treelite_params_t {
// suggested values (if nonzero) are from 2 to 7
// if zero, launches ceildiv(num_rows, NITEMS) blocks
int blocks_per_sm;
// threads_per_tree determines how many threads work on a single tree at once inside a block
// can only be a power of 2
int threads_per_tree;
// n_items is how many input samples (items) any thread processes. If 0 is given,
// choose the largest value (up to 4) that fits into shared memory.
int n_items;
// if non-nullptr, *pforest_shape_str will be set to caller-owned string that
// contains forest shape
char** pforest_shape_str;
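As a caller-side illustration of the two new fields, here is a minimal sketch of filling the extended treelite_params_t and building a forest, in the style of the benchmark change above; the algo and storage choices and threads_per_tree = 4 are assumptions for the example, not defaults.

ML::fil::treelite_params_t tl_params = {
  .algo              = ML::fil::algo_t::BATCH_TREE_REORG,
  .output_class      = false,  // regression or binary raw scores
  .threshold         = 0.5f,   // only consulted when output_class == true
  .storage_type      = ML::fil::storage_type_t::AUTO,
  .blocks_per_sm     = 8,
  .threads_per_tree  = 4,      // power of 2, at most the CUDA block size
  .n_items           = 0,      // 0: let FIL pick up to 4 rows per thread
  .pforest_shape_str = nullptr};
ML::fil::forest_t forest = nullptr;
// handle is a raft::handle_t, model is a treelite ModelHandle
ML::fil::from_treelite(handle, &forest, model, &tl_params);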
17 changes: 13 additions & 4 deletions cpp/src/fil/common.cuh
@@ -121,13 +121,22 @@ struct shmem_size_params {
/// whether the input columns are prefetched into shared
/// memory before inferring the row in question
bool cols_in_shmem = true;
/// n_items is the most items per thread that fit into shared memory
/// log2_threads_per_tree determines how many threads work on a single tree
/// at once inside a block (sharing trees means splitting input rows)
int log2_threads_per_tree = 0;
/// n_items is how many input samples (items) any thread processes. If 0 is given,
/// choose the largest reasonable value (<= 4) that fits into shared memory. See init_n_items()
int n_items = 0;
/// shm_sz is the associated shared memory footprint
int shm_sz = INT_MAX;

__host__ __device__ size_t cols_shmem_size() {
return cols_in_shmem ? sizeof(float) * num_cols * n_items : 0;
__host__ __device__ int sdata_stride() {
return num_cols | 1; // pad to odd
}
__host__ __device__ int cols_shmem_size() {
return cols_in_shmem
? sizeof(float) * sdata_stride() * n_items << log2_threads_per_tree
: 0;
}
void compute_smem_footprint();
template <int NITEMS>
@@ -148,7 +157,7 @@ struct predict_params : shmem_size_params {
float* preds;
const float* data;
// number of data rows (instances) to predict on
size_t num_rows;
int64_t num_rows;

// to signal infer kernel to apply softmax and also average prior to that
// for GROVE_PER_CLASS for predict_proba
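A side note on sdata_stride(): padding an even column count to an odd stride (num_cols | 1) keeps consecutive cached rows from starting in the same shared memory bank, and infer_k below relies on the identity idx / num_cols * sdata_stride + idx % num_cols == idx + idx / num_cols when sdata_stride == num_cols + 1. A tiny host-side check of that identity (a standalone sketch, not cuML code):

#include <cassert>

int main() {
  // only even num_cols get padded: stride == num_cols + 1
  for (int num_cols = 2; num_cols <= 64; num_cols += 2) {
    int stride = num_cols | 1;
    for (int idx = 0; idx < 4 * num_cols; ++idx) {
      // row-major position of element idx in the padded row cache
      int padded = idx / num_cols * stride + idx % num_cols;
      assert(padded == idx + idx / num_cols);  // the shortcut used in infer_k
    }
  }
  return 0;
}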
33 changes: 31 additions & 2 deletions cpp/src/fil/fil.cu
@@ -81,6 +81,15 @@ struct forest {
int max_shm = 0;
CUDA_CHECK(cudaDeviceGetAttribute(
&max_shm, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
/* Our GPUs have been growing the shared memory size generation after
generation. Eventually, a CUDA GPU might come along that supports more
shared memory than would fit into an unsigned 16-bit int. For such a GPU,
we would have otherwise silently overflowed the index calculation due
to short division. It would have failed cpp tests, but we might forget
about this source of bugs, if not for the failing assert. */
ASSERT(max_shm < 262144,
"internal error: please use a larger type inside"
" infer_k for column count");
// TODO(canonizer): use >48KiB shared memory if available
max_shm = std::min(max_shm, max_shm_std);

@@ -91,10 +100,14 @@
shmem_size_params& ssp_ = predict_proba ? proba_ssp_ : class_ssp_;
ssp_.predict_proba = predict_proba;
shmem_size_params ssp = ssp_;
// if n_items was not provided, try from 1 to 4. Otherwise, use as-is.
int min_n_items = ssp.n_items == 0 ? 1 : ssp.n_items;
int max_n_items = ssp.n_items == 0
? (algo_ == algo_t::BATCH_TREE_REORG ? 4 : 1)
: ssp.n_items;
for (bool cols_in_shmem : {false, true}) {
ssp.cols_in_shmem = cols_in_shmem;
for (ssp.n_items = 1;
ssp.n_items <= (algo_ == algo_t::BATCH_TREE_REORG ? 4 : 1);
for (ssp.n_items = min_n_items; ssp.n_items <= max_n_items;
++ssp.n_items) {
ssp.compute_smem_footprint();
if (ssp.shm_sz < max_shm) ssp_ = ssp;
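For a feel of what this search does, here is a host-side sketch that evaluates just the row-cache term of the footprint (the cols_shmem_size() formula from common.cuh). The 100 columns, 32 threads per tree, and 48 KiB budget are assumptions for the example; the real search also adds the accumulator footprint via compute_smem_footprint().

#include <cstdio>

int main() {
  const int num_cols = 100, log2_threads_per_tree = 5;  // assumed model/config
  const int max_shm  = 48 * 1024;                       // assumed budget
  const int sdata_stride = num_cols | 1;                // 101: padded to odd
  for (int n_items = 1; n_items <= 4; ++n_items) {
    int cols_shmem =
      (int)sizeof(float) * sdata_stride * n_items << log2_threads_per_tree;
    printf("n_items=%d: row cache %6d bytes (%s)\n", n_items, cols_shmem,
           cols_shmem < max_shm ? "fits" : "too large");
  }
  return 0;
}

Under these assumptions the row cache alone would cap n_items at 3.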
@@ -126,6 +139,8 @@ struct forest {
output_ = params->output;
threshold_ = params->threshold;
global_bias_ = params->global_bias;
proba_ssp_.n_items = params->n_items;
proba_ssp_.log2_threads_per_tree = log2(params->threads_per_tree);
proba_ssp_.leaf_algo = params->leaf_algo;
proba_ssp_.num_cols = params->num_cols;
proba_ssp_.num_classes = params->num_classes;
@@ -412,12 +427,16 @@ void check_params(const forest_params_t* params, bool dense) {
"softmax does not make sense for leaf_algo == FLOAT_UNARY_BINARY");
break;
case leaf_algo_t::GROVE_PER_CLASS:
ASSERT(params->threads_per_tree == 1,
"multiclass not supported with threads_per_tree > 1");
ASSERT(params->num_classes > 2,
"num_classes > 2 is required for leaf_algo == GROVE_PER_CLASS");
ASSERT(params->num_trees % params->num_classes == 0,
"num_classes must divide num_trees evenly for GROVE_PER_CLASS");
break;
case leaf_algo_t::CATEGORICAL_LEAF:
ASSERT(params->threads_per_tree == 1,
"multiclass not supported with threads_per_tree > 1");
ASSERT(params->num_classes >= 2,
"num_classes >= 2 is required for "
"leaf_algo == CATEGORICAL_LEAF");
@@ -437,6 +456,14 @@
ASSERT(~params->output & (output_t::SIGMOID | output_t::SOFTMAX),
"combining softmax and sigmoid is not supported");
ASSERT(params->blocks_per_sm >= 0, "blocks_per_sm must be nonnegative");
ASSERT(params->n_items >= 0, "n_items must be non-negative");
ASSERT(params->threads_per_tree > 0, "threads_per_tree must be positive");
ASSERT(thrust::detail::is_power_of_2(params->threads_per_tree),
"threads_per_tree must be a power of 2");
ASSERT(params->threads_per_tree <= FIL_TPB,
"threads_per_tree must not "
"exceed block size %d",
FIL_TPB);
}

template <typename T, typename L>
@@ -806,6 +833,8 @@ void tl2fil_common(forest_params_t* params, const tl::ModelImpl<T, L>& model,
params->output = output_t(params->output | output_t::SOFTMAX);
params->num_trees = model.trees.size();
params->blocks_per_sm = tl_params->blocks_per_sm;
params->threads_per_tree = tl_params->threads_per_tree;
params->n_items = tl_params->n_items;
}

// uses treelite model with additional tl_params to initialize FIL params
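The new checks boil down to: n_items must be non-negative, and threads_per_tree must be a positive power of 2 no larger than the block size, since init() stores log2(threads_per_tree). A standalone sketch of the same constraint, with local is_power_of_2/ilog2 helpers standing in for the thrust/raft ones and FIL_TPB = 256 assumed:

#include <cstdio>

constexpr int FIL_TPB = 256;  // assumed block size

static bool is_power_of_2(int x) { return x > 0 && (x & (x - 1)) == 0; }
static int  ilog2(int x) { int l = 0; while (x >>= 1) ++l; return l; }

int main() {
  for (int threads_per_tree : {1, 4, 32, 48, 256, 512}) {
    bool ok = is_power_of_2(threads_per_tree) && threads_per_tree <= FIL_TPB;
    if (ok)
      printf("threads_per_tree=%3d ok, log2_threads_per_tree=%d\n",
             threads_per_tree, ilog2(threads_per_tree));
    else
      printf("threads_per_tree=%3d rejected\n", threads_per_tree);
  }
  return 0;
}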
132 changes: 93 additions & 39 deletions cpp/src/fil/infer.cu
@@ -218,8 +218,10 @@ struct tree_aggregator_t {
value is computed.
num_classes is used for other template parameters */
static size_t smem_finalize_footprint(size_t data_row_size, int num_classes,
int log2_threads_per_tree,
bool predict_proba) {
return block_reduce_footprint_host<NITEMS>();
return log2_threads_per_tree != 0 ? FIL_TPB * NITEMS * sizeof(float)
: block_reduce_footprint_host<NITEMS>();
}

/** shared memory footprint of the accumulator during
@@ -236,19 +238,41 @@
: tmp_storage(finalize_workspace) {}

__device__ __forceinline__ void accumulate(
vec<NITEMS, float> single_tree_prediction, int tree, int num_rows) {
vec<NITEMS, float> single_tree_prediction, int tree, int thread_num_rows) {
acc += single_tree_prediction;
}

__device__ __forceinline__ void finalize(float* out, int num_rows,
__device__ __forceinline__ void finalize(float* block_out, int block_num_rows,
int output_stride,
output_t transform, int num_trees) {
__syncthreads();
acc = block_reduce(acc, vectorized(cub::Sum()), tmp_storage);
if (threadIdx.x > 0) return;
output_t transform, int num_trees,
int log2_threads_per_tree) {
if (FIL_TPB != 1 << log2_threads_per_tree) { // anything to reduce?
// ensure input columns can be overwritten (no threads traversing trees)
__syncthreads();
if (log2_threads_per_tree == 0) {
acc = block_reduce(acc, vectorized(cub::Sum()), tmp_storage);
} else {
auto per_thread = (vec<NITEMS, float>*)tmp_storage;
per_thread[threadIdx.x] = acc;
__syncthreads();
// We have two pertinent cases for splitting FIL_TPB == 256 values:
// 1. 2000 columns, which fit few threads/tree in shared memory,
// so ~256 groups. These are the models that will run the slowest.
// multi_sum performance is not sensitive to the radix here.
// 2. 50 columns, so ~32 threads/tree, so ~8 groups. These are the most
// popular.
acc = multi_sum<5>(per_thread, 1 << log2_threads_per_tree,
FIL_TPB >> log2_threads_per_tree);
}
}

if (threadIdx.x * NITEMS >= block_num_rows) return;
#pragma unroll
for (int row = 0; row < NITEMS; ++row)
if (row < num_rows) out[row * output_stride] = acc[row];
for (int row = 0; row < NITEMS; ++row) {
int out_preds_i = threadIdx.x * NITEMS + row;
if (out_preds_i < block_num_rows)
block_out[out_preds_i * output_stride] = acc[row];
}
}
};
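When threads_per_tree > 1, each thread parks its partial vec<NITEMS, float> in shared memory and multi_sum adds up the partials of the FIL_TPB / threads_per_tree threads that walked different trees over the same rows (they share the low log2_threads_per_tree bits of threadIdx.x). A host-side illustration of that grouping with plain loops, scalar partials instead of vec<NITEMS, float>, and assumed sizes; this is not the actual multi_sum device code:

#include <cstdio>
#include <vector>

int main() {
  const int FIL_TPB = 256, threads_per_tree = 32;    // assumed sizes
  // pretend every thread accumulated 1.0 from the trees it traversed
  std::vector<float> per_thread(FIL_TPB, 1.0f);
  std::vector<float> per_row_group(threads_per_tree, 0.0f);
  for (int t = 0; t < FIL_TPB; ++t)
    per_row_group[t & (threads_per_tree - 1)] += per_thread[t];
  // each row group sums FIL_TPB / threads_per_tree == 8 partials
  printf("row group 0 total: %g\n", per_row_group[0]);
  return 0;
}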

@@ -362,6 +386,7 @@ struct tree_aggregator_t<NITEMS, GROVE_PER_CLASS_FEW_CLASSES> {
void* tmp_storage;

static size_t smem_finalize_footprint(size_t data_row_size, int num_classes,
int log2_threads_per_tree,
bool predict_proba) {
size_t phase1 =
(FIL_TPB - FIL_TPB % num_classes) * sizeof(vec<NITEMS, float>);
@@ -382,13 +407,14 @@
: finalize_workspace) {}

__device__ __forceinline__ void accumulate(
vec<NITEMS, float> single_tree_prediction, int tree, int num_rows) {
vec<NITEMS, float> single_tree_prediction, int tree, int thread_num_rows) {
acc += single_tree_prediction;
}

__device__ __forceinline__ void finalize(float* out, int num_rows,
int num_outputs, output_t transform,
int num_trees) {
int num_trees,
int log2_threads_per_tree) {
__syncthreads(); // free up input row in case it was in shared memory
// load margin into shared memory
per_thread[threadIdx.x] = acc;
@@ -412,6 +438,7 @@ struct tree_aggregator_t<NITEMS, GROVE_PER_CLASS_MANY_CLASSES> {
int num_classes;

static size_t smem_finalize_footprint(size_t data_row_size, int num_classes,
int log2_threads_per_tree,
bool predict_proba) {
size_t phase1 = data_row_size + smem_accumulate_footprint(num_classes);
size_t phase2 = predict_proba
@@ -437,15 +464,16 @@
}

__device__ __forceinline__ void accumulate(
vec<NITEMS, float> single_tree_prediction, int tree, int num_rows) {
vec<NITEMS, float> single_tree_prediction, int tree, int thread_num_rows) {
// since threads are assigned to consecutive classes, no need for atomics
per_class_margin[tree % num_classes] += single_tree_prediction;
// __syncthreads() is called in infer_k
}

__device__ __forceinline__ void finalize(float* out, int num_rows,
int num_outputs, output_t transform,
int num_trees) {
int num_trees,
int log2_threads_per_tree) {
class_margins_to_global_memory(
per_class_margin, per_class_margin + num_classes, transform,
num_trees / num_classes, tmp_storage, out, num_rows, num_outputs);
Expand All @@ -454,12 +482,13 @@ struct tree_aggregator_t<NITEMS, GROVE_PER_CLASS_MANY_CLASSES> {

template <int NITEMS>
struct tree_aggregator_t<NITEMS, CATEGORICAL_LEAF> {
// could switch to unsigned short to save shared memory
// could switch to uint16_t to save shared memory
// provided raft::myAtomicAdd(short*) simulated with appropriate shifts
int* votes;
int num_classes;

static size_t smem_finalize_footprint(size_t data_row_size, int num_classes,
int log2_threads_per_tree,
bool predict_proba) {
// not accounting for lingering accumulate_footprint during finalize()
return 0;
Expand All @@ -478,7 +507,7 @@ struct tree_aggregator_t<NITEMS, CATEGORICAL_LEAF> {
// __syncthreads() is called in infer_k
}
__device__ __forceinline__ void accumulate(
vec<NITEMS, int> single_tree_prediction, int tree, int num_rows) {
vec<NITEMS, int> single_tree_prediction, int tree, int thread_num_rows) {
#pragma unroll
for (int item = 0; item < NITEMS; ++item)
raft::myAtomicAdd(votes + single_tree_prediction[item] * NITEMS + item,
@@ -516,7 +545,8 @@ struct tree_aggregator_t<NITEMS, CATEGORICAL_LEAF> {
}
__device__ __forceinline__ void finalize(float* out, int num_rows,
int num_outputs, output_t transform,
int num_trees) {
int num_trees,
int log2_threads_per_tree) {
if (num_outputs > 1) {
// only supporting num_outputs == num_classes
finalize_multiple_outputs(out, num_rows);
@@ -531,45 +561,69 @@ template <int NITEMS, leaf_algo_t leaf_algo, bool cols_in_shmem,
__global__ void infer_k(storage_type forest, predict_params params) {
extern __shared__ char smem[];
float* sdata = (float*)smem;
int sdata_stride = params.sdata_stride();
int rows_per_block = NITEMS << params.log2_threads_per_tree;
int num_cols = params.num_cols;
for (size_t block_row0 = blockIdx.x * NITEMS; block_row0 < params.num_rows;
block_row0 += NITEMS * gridDim.x) {
size_t num_input_rows = min((size_t)NITEMS, params.num_rows - block_row0);
int thread_row0 = NITEMS * modpow2(threadIdx.x, params.log2_threads_per_tree);
for (int64_t block_row0 = blockIdx.x * rows_per_block;
block_row0 < params.num_rows; block_row0 += rows_per_block * gridDim.x) {
int block_num_rows = max(
0,
(int)min((int64_t)rows_per_block, (int64_t)params.num_rows - block_row0));
const float* block_input = params.data + block_row0 * num_cols;
if (cols_in_shmem) {
// cache the row for all threads to reuse
size_t feature = 0;
// 2021: the latest SMs still do not have the >256KiB of shared memory per block required to
// overflow a uint16_t index
#pragma unroll
for (feature = threadIdx.x; feature < num_input_rows * num_cols;
feature += blockDim.x)
sdata[feature] = block_input[feature];
for (uint16_t input_idx = threadIdx.x;
input_idx < block_num_rows * num_cols; input_idx += blockDim.x) {
// for even num_cols, we need to pad sdata_stride to reduce bank conflicts
// assuming here that sdata_stride == num_cols + 1
// then, idx / num_cols * sdata_stride + idx % num_cols == idx + idx / num_cols
uint16_t sdata_idx = sdata_stride == num_cols
? input_idx
: input_idx + input_idx / (uint16_t)num_cols;
sdata[sdata_idx] = block_input[input_idx];
}
#pragma unroll
for (; feature < NITEMS * num_cols; feature += blockDim.x)
sdata[feature] = 0.0f;
for (int idx = block_num_rows * sdata_stride;
idx < rows_per_block * sdata_stride; idx += blockDim.x)
sdata[idx] = 0.0f;
}

tree_aggregator_t<NITEMS, leaf_algo> acc(
params, (char*)sdata + params.cols_shmem_size(), sdata);

__syncthreads(); // for both row cache init and acc init

// one block works on NITEMS rows and the whole forest
for (int j = threadIdx.x; j - threadIdx.x < forest.num_trees();
j += blockDim.x) {
/* j - threadIdx.x < forest.num_trees() is a necessary but block-uniform
condition for "j < forest.num_trees()". It lets use __syncthreads()
// one block works on NITEMS * threads_per_tree rows and the whole forest
// one thread works on NITEMS rows

int thread_tree0 = threadIdx.x >> params.log2_threads_per_tree;
int tree_stride = blockDim.x >> params.log2_threads_per_tree;
int thread_num_rows = max(0, min(NITEMS, block_num_rows - thread_row0));
for (int tree = thread_tree0; tree - thread_tree0 < forest.num_trees();
tree += tree_stride) {
/* tree - thread_tree0 < forest.num_trees() is a necessary but block-uniform
condition for "tree < forest.num_trees()". It lets use __syncthreads()
and is made exact below.
Same with thread_num_rows > 0
*/
if (j < forest.num_trees()) {
acc.accumulate(infer_one_tree<NITEMS, leaf_output_t<leaf_algo>::T>(
forest[j], cols_in_shmem ? sdata : block_input,
num_cols, num_input_rows),
j, num_input_rows);
if (tree < forest.num_trees() && thread_num_rows != 0) {
typedef typename leaf_output_t<leaf_algo>::T pred_t;
vec<NITEMS, pred_t> prediction = infer_one_tree<NITEMS, pred_t>(
forest[tree],
cols_in_shmem ? sdata + thread_row0 * sdata_stride
: block_input + thread_row0 * num_cols,
cols_in_shmem ? sdata_stride : num_cols,
cols_in_shmem ? NITEMS : thread_num_rows);
acc.accumulate(prediction, tree, thread_num_rows);
}
if (leaf_algo == GROVE_PER_CLASS_MANY_CLASSES) __syncthreads();
}
acc.finalize(params.preds + params.num_outputs * block_row0, num_input_rows,
params.num_outputs, params.transform, forest.num_trees());
acc.finalize(params.preds + params.num_outputs * block_row0, block_num_rows,
params.num_outputs, params.transform, forest.num_trees(),
params.log2_threads_per_tree);
__syncthreads(); // free up acc's shared memory resources for next row set
}
}
Expand All @@ -578,7 +632,7 @@ template <int NITEMS, leaf_algo_t leaf_algo>
size_t shmem_size_params::get_smem_footprint() {
size_t finalize_footprint =
tree_aggregator_t<NITEMS, leaf_algo>::smem_finalize_footprint(
cols_shmem_size(), num_classes, predict_proba);
cols_shmem_size(), num_classes, log2_threads_per_tree, predict_proba);
size_t accumulate_footprint =
tree_aggregator_t<NITEMS, leaf_algo>::smem_accumulate_footprint(
num_classes) +
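To summarize the indexing introduced in infer_k: a block now covers NITEMS << log2_threads_per_tree rows per iteration, each thread owns an NITEMS-row slice picked by the low bits of threadIdx.x, and the high bits pick which trees it walks. A standalone sketch of that partitioning (modpow2 reimplemented locally as x & (2^n - 1); FIL_TPB = 256, NITEMS = 4, and threads_per_tree = 8 are assumptions):

#include <cstdio>

constexpr int FIL_TPB = 256;              // assumed block size
constexpr int NITEMS = 4;                 // rows per thread
constexpr int log2_threads_per_tree = 3;  // threads_per_tree == 8

static int modpow2(int x, int n) { return x & ((1 << n) - 1); }  // local stand-in

int main() {
  int rows_per_block = NITEMS << log2_threads_per_tree;  // 32 rows per block iteration
  int tree_stride = FIL_TPB >> log2_threads_per_tree;    // 32 trees traversed in parallel
  printf("rows_per_block=%d tree_stride=%d\n", rows_per_block, tree_stride);
  for (int tid : {0, 1, 7, 8, 255}) {
    int thread_row0  = NITEMS * modpow2(tid, log2_threads_per_tree);  // first row of this thread
    int thread_tree0 = tid >> log2_threads_per_tree;                  // first tree of this thread
    printf("threadIdx.x=%3d: rows [%d, %d), trees %d, %d, ...\n", tid, thread_row0,
           thread_row0 + NITEMS, thread_tree0, thread_tree0 + tree_stride);
  }
  return 0;
}

The output shows threads 0 and 8 sharing rows [0, 4) while starting on different trees, which is exactly what the finalize-side reduction above has to merge.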