Use opt-in shared memory carveout for FIL #3759

Merged · 24 commits · Nov 12, 2021
Commits
53e0683
try 1
levsnv Apr 17, 2021
31b885f
draft of set-and-launch
levsnv May 25, 2021
589f9ad
Merge remote-tracking branch 'rapidsai/branch-21.06' into extra-share…
levsnv May 25, 2021
26480b0
set carveout and occupancy-affecting preferred cache config before ev…
levsnv May 26, 2021
a037f16
other review comments
levsnv May 26, 2021
2a1d622
DRY: rewrote in terms of dispatch_on_FIL_template_params<func, storag…
levsnv Jun 12, 2021
bd3c505
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into extra-sh…
levsnv Jun 12, 2021
5cf38d3
style, clean up diff
levsnv Jun 12, 2021
7959f4d
Merge branch 'branch-21.08' of github.com:rapidsai/cuml into extra-sh…
levsnv Jun 14, 2021
258e674
fixed bugs and linker issues
levsnv Jun 15, 2021
e0f53ea
removed unnecessary specialization in dispatch
levsnv Jun 15, 2021
3dcece3
Merge remote-tracking branch 'rapidsai/branch-21.12' into extra-share…
levsnv Nov 3, 2021
1a198f3
Merge branch 'branch-21.12' of github.com:rapidsai/cuml into extra-sh…
levsnv Nov 3, 2021
1237797
updated to new dispatch* changes
levsnv Nov 3, 2021
b8ef50d
fixed old shm_sz placement
levsnv Nov 3, 2021
e0ef92e
noinline
levsnv Nov 3, 2021
a782c79
removed unused method
levsnv Nov 3, 2021
43967f0
Apply suggestions from code review
levsnv Nov 9, 2021
0f00f39
some review comments
levsnv Nov 9, 2021
beb2ed8
Merge branch 'extra-shared-memory' of github.com:levsnv/cuml into ext…
levsnv Nov 9, 2021
73a0d10
misc
levsnv Nov 9, 2021
7938c13
simplified the code by unconditionally enabling opt-in at model load
levsnv Nov 9, 2021
161065d
same type on both sides of the comparison
levsnv Nov 10, 2021
27bdced
static_cast instead of C-style cast
levsnv Nov 10, 2021
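
Before the per-file diffs, a quick recap of the CUDA mechanism this PR builds on (a minimal standalone sketch, not cuML code; the kernel and variable names below are made up): the device's opt-in limit is queried via cudaDevAttrMaxSharedMemoryPerBlockOptin, a kernel that needs more than the default 48 KiB is registered once with cudaFuncSetAttribute, and later launches of that kernel may then request up to that much dynamic shared memory.

#include <cuda_runtime.h>
#include <cstdio>

__global__ void big_smem_kernel(float* out)
{
  extern __shared__ float buf[];  // dynamic shared memory, sized at launch time
  buf[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = buf[0];
}

int main()
{
  int device = 0, max_shm = 0;
  cudaDeviceGetAttribute(&max_shm, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

  // without this opt-in, launches requesting more than 48 KiB of dynamic shared memory fail
  cudaFuncSetAttribute(big_smem_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shm);

  float* out = nullptr;
  cudaMalloc(&out, sizeof(float));
  big_smem_kernel<<<1, 256, max_shm>>>(out);  // may now use the full opt-in carveout
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(out);
  return 0;
}

FIL applies the same two calls, but per kernel template instantiation and tied to its own occupancy bookkeeping, as the fil.cu diff below shows.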
7 changes: 7 additions & 0 deletions cpp/src/fil/common.cuh
@@ -285,6 +285,13 @@ struct compute_smem_footprint : dispatch_functor<int> {
int run(predict_params);
};

template <int NITEMS,
leaf_algo_t leaf_algo,
bool cols_in_shmem,
bool CATS_SUPPORTED,
class storage_type>
__global__ void infer_k(storage_type forest, predict_params params);

// infer() calls the inference kernel with the parameters on the stream
template <typename storage_type>
void infer(storage_type forest, predict_params params, cudaStream_t stream);
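The forward declaration added above lets fil.cu name each infer_k instantiation on the host when raising its shared memory limit. For orientation, the kernel consumes dynamically sized shared memory roughly along these lines (an assumed shape for illustration only, not the real infer_k):

template <int NITEMS, bool COLS_IN_SHMEM>
__global__ void infer_k_demo(const float* data, int num_cols)
{
  // dynamic shared memory; its size is the third launch-configuration argument,
  // which the opt-in allows to exceed the 48 KiB default
  extern __shared__ float sdata[];
  if (COLS_IN_SHMEM) {
    // cache one input row per block before traversing the trees
    for (int c = threadIdx.x; c < num_cols; c += blockDim.x)
      sdata[c] = data[blockIdx.x * num_cols + c];
    __syncthreads();
  }
  // ... tree traversal over the cached (or global-memory) row would follow ...
}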
46 changes: 37 additions & 9 deletions cpp/src/fil/fil.cu
@@ -76,24 +76,23 @@ extern template int dispatch_on_fil_template_params(compute_smem_footprint, pred
struct forest {
forest(const raft::handle_t& h) : vector_leaf_(0, h.get_stream()), cat_sets_(h.get_stream()) {}

void init_n_items(int device)
void init_shmem_size(int device)
{
int max_shm_std = 48 * 1024; // 48 KiB
/// the most shared memory a kernel can request on the GPU in question
int max_shm = 0;
CUDA_CHECK(cudaDeviceGetAttribute(&max_shm, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
CUDA_CHECK(cudaDeviceGetAttribute(&max_shm_, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
/* Our GPUs have been growing the shared memory size generation after
generation. Eventually, a CUDA GPU might come along that supports more
shared memory than would fit into an unsigned 16-bit int. For such a GPU,
we would have otherwise silently overflowed the index calculation due
to short division. That would have failed the C++ tests, but without this
failing assert we might forget about the source of such bugs. */
ASSERT(max_shm < 262144,
ASSERT(max_shm_ < int(sizeof(float)) * std::numeric_limits<uint16_t>::max(),
"internal error: please use a larger type inside"
" infer_k for column count");
// TODO(canonizer): use >48KiB shared memory if available
max_shm = std::min(max_shm, max_shm_std);
}

void init_n_items(int device)
{
// searching for the most items per block while respecting the shared
// memory limits creates a full linear programming problem.
// solving it in a single equation looks less tractable than this
@@ -109,10 +108,10 @@ struct forest {
ssp.cols_in_shmem = cols_in_shmem;
for (ssp.n_items = min_n_items; ssp.n_items <= max_n_items; ++ssp.n_items) {
ssp.shm_sz = dispatch_on_fil_template_params(compute_smem_footprint(), ssp);
if (ssp.shm_sz < max_shm) ssp_ = ssp;
if (ssp.shm_sz < max_shm_) ssp_ = ssp;
}
}
ASSERT(max_shm >= ssp_.shm_sz,
ASSERT(max_shm_ >= ssp_.shm_sz,
"FIL out of shared memory. Perhaps the maximum number of \n"
"supported classes is exceeded? 5'000 would still be safe.");
}
@@ -149,6 +148,7 @@ struct forest {

int device = h.get_device();
cudaStream_t stream = h.get_stream();
init_shmem_size(device);
init_n_items(device); // n_items takes priority over blocks_per_sm
init_fixed_block_count(device, params->blocks_per_sm);

@@ -325,11 +325,31 @@ struct forest {
float global_bias_ = 0;
shmem_size_params class_ssp_, proba_ssp_;
int fixed_block_count_ = 0;
int max_shm_ = 0;
// Optionally used
rmm::device_uvector<float> vector_leaf_;
cat_sets_device_owner cat_sets_;
};

template <typename storage_type>
struct opt_into_arch_dependent_shmem : dispatch_functor<void> {
const int max_shm;
opt_into_arch_dependent_shmem(int max_shm_) : max_shm(max_shm_) {}

template <typename KernelParams = KernelTemplateParams<>>
void run(predict_params p)
{
auto kernel = infer_k<KernelParams::N_ITEMS,
KernelParams::LEAF_ALGO,
KernelParams::COLS_IN_SHMEM,
KernelParams::CATS_SUPPORTED,
storage_type>;
// p.shm_sz might be > max_shm or < MAX_SHM_STD, but we should not check for either, because
// we don't run on both proba_ssp_ and class_ssp_ (only class_ssp_). This should be quick.
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shm));
}
};

struct dense_forest : forest {
dense_forest(const raft::handle_t& h) : forest(h), nodes_(0, h.get_stream()) {}

@@ -381,6 +401,10 @@ struct dense_forest : forest {
num_nodes * sizeof(dense_node),
cudaMemcpyHostToDevice,
h.get_stream()));

// predict_proba is a runtime parameter, and opt-in is unconditional
dispatch_on_fil_template_params(opt_into_arch_dependent_shmem<dense_storage>(max_shm_),
static_cast<predict_params>(class_ssp_));
// copy must be finished before freeing the host data
CUDA_CHECK(cudaStreamSynchronize(h.get_stream()));
h_nodes_.clear();
@@ -436,6 +460,10 @@ struct sparse_forest : forest {
nodes_.resize(num_nodes_, h.get_stream());
CUDA_CHECK(cudaMemcpyAsync(
nodes_.data(), nodes, sizeof(node_t) * num_nodes_, cudaMemcpyHostToDevice, h.get_stream()));

// predict_proba is a runtime parameter, and opt-in is unconditional
dispatch_on_fil_template_params(opt_into_arch_dependent_shmem<sparse_storage<node_t>>(max_shm_),
static_cast<predict_params>(class_ssp_));
}

virtual void infer(predict_params params, cudaStream_t stream) override
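The opt_into_arch_dependent_shmem functor above is applied through the existing dispatch_on_fil_template_params machinery, which maps runtime parameters to one compile-time kernel instantiation. A compressed sketch of that pattern, with hypothetical names and only two template axes instead of FIL's full set (the real functor also receives predict_params in run()):

#include <cuda_runtime.h>

template <int NITEMS, bool COLS_IN_SHMEM>
__global__ void infer_k_stub(const float* data)
{
  // placeholder body; the real infer_k stages `data` through dynamic shared memory
}

struct opt_into_shmem {
  int max_shm;
  template <int NITEMS, bool COLS_IN_SHMEM>
  void run() const
  {
    // take a host-side pointer to this particular instantiation and raise its
    // dynamic shared memory ceiling above the 48 KiB default
    auto kernel = infer_k_stub<NITEMS, COLS_IN_SHMEM>;
    cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shm);
  }
};

// runtime-to-compile-time dispatch: select the instantiation matching the runtime values
template <int NITEMS, typename Func>
void dispatch_cols(bool cols_in_shmem, const Func& f)
{
  if (cols_in_shmem)
    f.template run<NITEMS, true>();
  else
    f.template run<NITEMS, false>();
}

template <typename Func>
void dispatch(int n_items, bool cols_in_shmem, const Func& f)
{
  switch (n_items) {
    case 1: dispatch_cols<1>(cols_in_shmem, f); break;
    case 2: dispatch_cols<2>(cols_in_shmem, f); break;
    case 3: dispatch_cols<3>(cols_in_shmem, f); break;
    case 4: dispatch_cols<4>(cols_in_shmem, f); break;
  }
}

// usage at model load, in the spirit of the calls added to dense_forest::init and sparse_forest::init:
// dispatch(ssp.n_items, ssp.cols_in_shmem, opt_into_shmem{max_shm});

Doing this once at model load time keeps the opt-in off the inference hot path.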
4 changes: 4 additions & 0 deletions cpp/src/fil/internal.cuh
@@ -520,7 +520,11 @@ void init_sparse(const raft::handle_t& h,
const fil_node_t* nodes,
const forest_params_t* params);

struct predict_params;

} // namespace fil

static const int MAX_SHM_STD = 48 * 1024; // maximum architecture-independent size

std::string output2str(fil::output_t output);
} // namespace ML
15 changes: 15 additions & 0 deletions cpp/test/sg/fil_test.cu
@@ -1040,6 +1040,21 @@ std::vector<FilTestParams> predict_dense_inputs = {
algo = BATCH_TREE_REORG,
leaf_algo = CATEGORICAL_LEAF,
num_classes = 3),
// use shared memory opt-in carveout if available, or infer out of L1 cache
FIL_TEST_PARAMS(num_rows = 103, num_cols = MAX_SHM_STD / sizeof(float) + 1024, algo = NAIVE),
FIL_TEST_PARAMS(num_rows = 103,
num_cols = MAX_SHM_STD / sizeof(float) + 1024,
leaf_algo = GROVE_PER_CLASS,
num_classes = 5),
FIL_TEST_PARAMS(num_rows = 103,
num_cols = MAX_SHM_STD / sizeof(float) + 1024,
num_trees = FIL_TPB + 1,
leaf_algo = GROVE_PER_CLASS,
num_classes = FIL_TPB + 1),
FIL_TEST_PARAMS(num_rows = 103,
num_cols = MAX_SHM_STD / sizeof(float) + 1024,
leaf_algo = CATEGORICAL_LEAF,
num_classes = 3),
FIL_TEST_PARAMS(algo = BATCH_TREE_REORG, threads_per_tree = 2),
FIL_TEST_PARAMS(algo = NAIVE, threads_per_tree = 4),
FIL_TEST_PARAMS(algo = TREE_REORG, threads_per_tree = 8),
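
The new test sizes are chosen so that caching one input row in shared memory overflows the default carveout and therefore exercises either the opt-in path or the fallback that reads columns from global memory. A standalone back-of-the-envelope check, assuming one float per column per cached row and the 48 KiB MAX_SHM_STD added to internal.cuh:

// compile-time sanity check; MAX_SHM_STD here mirrors the constant from internal.cuh
constexpr int MAX_SHM_STD = 48 * 1024;                                              // 49152 bytes
constexpr int num_cols    = MAX_SHM_STD / static_cast<int>(sizeof(float)) + 1024;   // 12288 + 1024 = 13312
constexpr int row_bytes   = num_cols * static_cast<int>(sizeof(float));             // 53248 bytes
static_assert(row_bytes > MAX_SHM_STD,
              "one cached row no longer fits in the default 48 KiB carveout");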