Enable NaN checks on tensor arguments to kernel launches #4029

Open
wants to merge 2 commits into main
25 changes: 25 additions & 0 deletions fbgemm_gpu/codegen/genscript/jinja_environment.py
@@ -388,6 +388,30 @@ def make_pta_acc_format(pta_str_list: List[str], func_name: str) -> List[str]:
    return new_str_list


def make_pta_acc_builder_format(pta_str_list: List[str]) -> List[str]:
    """Rewrite `tensor.packed_accessorNN<...>()` expressions into the
    equivalent PTA_B / PTA_ACC_B builder-macro calls consumed by
    FBGEMM_LAUNCH_KERNEL, e.g.

        "foo.packed_accessor64<at::acc_type<cache_t, true>, 1, at::RestrictPtrTraits>()"
        -> "PTA_ACC_B(foo, cache_t, 1, 64)"
    """
    new_str_list = []
    for pta_str in pta_str_list:
        if "packed_accessor" in pta_str:
            match = re.search(
                r"([a-zA-Z0-9_]*)[.]packed_accessor([3|6][2|4])<(.*)>\(\)", pta_str
            )
            assert match is not None and len(match.groups()) == 3
            tensor, acc_nbits, args = match.groups()
            if "acc_type" in args:
                match = re.search("at::acc_type<([a-zA-Z_0-9]*), true>", args)
                assert match is not None and len(match.groups()) == 1
                new_type = match.group(1)
                args = re.sub("at::acc_type<[a-zA-Z_]*, true>", new_type, args)
                macro_name = "PTA_ACC_B"
            else:
                macro_name = "PTA_B"
            args = args.replace(", at::RestrictPtrTraits", "")
            new_str_list.append(f"{macro_name}({tensor}, {args}, {acc_nbits})")
        else:
            new_str_list.append(pta_str)
    return new_str_list


def replace_pta_namespace(pta_str_list: List[str]) -> List[str]:
    return [
        pta_str.replace("at::PackedTensorAccessor", "pta::PackedTensorAccessor")
@@ -431,6 +455,7 @@ def to_upper_placeholder_types(arg_str_list: List[str]) -> List[str]:
################################################################################

env.filters["make_pta_acc_format"] = make_pta_acc_format
env.filters["make_pta_acc_builder_format"] = make_pta_acc_builder_format
env.filters["replace_pta_namespace"] = replace_pta_namespace
env.filters["replace_placeholder_types"] = replace_placeholder_types
env.filters["to_upper_placeholder_types"] = to_upper_placeholder_types
309 changes: 149 additions & 160 deletions fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

Large diffs are not rendered by default.

@@ -9,6 +9,7 @@
// clang-format off
#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/utils/tensor_accessor_builder.h"
#include "fbgemm_gpu/utils/kernel_launcher.cuh"

using Tensor = at::Tensor;
using namespace fbgemm_gpu;
@@ -172,48 +173,43 @@ void split_embedding_{{ optimizer }}_update(
#else
constexpr int kThreadGroupSize = kWarpSize;
#endif
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = "split_{{ optimizer }}_update_kernel";
#endif

DISPATCH_PLACEHOLDER_TYPES(
{%- for ph_name in args.placeholder_tensor_names %}
{{ ph_name + "_dev" }}.scalar_type(),
{%- endfor %}
"split_embedding_{{ optimizer }}_update_placeholder_type_kernel",
[&] {
split_{{ optimizer }}_update_kernel<
emb_t,
cache_t,
{%- for ph_name in args.placeholder_tensor_names %}
{{ ph_name + "_ph_t" }},
{%- endfor %}
kMaxVecsPerThread,
kThreadGroupSize,
4>
<<<div_round_up(grad_dev_indices.numel(), kMaxThreads / kThreadGroupSize),
dim3(kThreadGroupSize, kMaxThreads / kThreadGroupSize, 1),
0, // Shared memory is not needed because uint8_t is not supported
at::cuda::getCurrentCUDAStream()
>>>
(
MAKE_PTA_WITH_NAME(func_name, dev_weights, emb_t, 1, 64),
MAKE_PTA_WITH_NAME(func_name, uvm_weights, emb_t, 1, 64),
MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, cache_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name, flatten_grad_dev_weights, emb_t, 1, 64),
MAKE_PTA_WITH_NAME(func_name, flatten_grad_dev_indices, int64_t, 1, 64),
MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
// Use weights_placements instead of
// sorted_lxu_cache_locations because LXU cache is not
// supported right now
MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32),
max_D,
stochastic_rounding,
rng_engine_inputs,
{{ args.split_kernel_arg_constructors | make_pta_acc_format("func_name") | join(", ") }}
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
FBGEMM_LAUNCH_KERNEL(
(split_{{ optimizer }}_update_kernel<
emb_t,
cache_t,
{%- for ph_name in args.placeholder_tensor_names %}
{{ ph_name + "_ph_t" }},
{%- endfor %}
kMaxVecsPerThread,
kThreadGroupSize,
4>),
div_round_up(grad_dev_indices.numel(), kMaxThreads / kThreadGroupSize),
dim3(kThreadGroupSize, kMaxThreads / kThreadGroupSize, 1),
0, // Shared memory is not needed because uint8_t is not supported
at::cuda::getCurrentCUDAStream(),
PTA_B(dev_weights, emb_t, 1, 64),
PTA_B(uvm_weights, emb_t, 1, 64),
PTA_B(lxu_cache_weights, cache_t, 2, 64),
PTA_B(flatten_grad_dev_weights, emb_t, 1, 64),
PTA_B(flatten_grad_dev_indices, int64_t, 1, 64),
PTA_B(weights_placements, int32_t, 1, 32),
PTA_B(weights_offsets, int64_t, 1, 32),
// Use weights_placements instead of
// sorted_lxu_cache_locations because LXU cache is not
// supported right now
PTA_B(weights_placements, int32_t, 1, 32),
max_D,
stochastic_rounding,
rng_engine_inputs,
{{ args.split_kernel_arg_constructors | make_pta_acc_builder_format() | join(", ") }}
);
}); // DISPATCH_PLACEHOLDER_TYPES
return;
}
103 changes: 74 additions & 29 deletions fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh
@@ -62,14 +62,35 @@ decltype(auto) transform_kernel_arg(const SourceContext& context, T&& arg) {
}
}

////////////////////////////////////////////////////////////////////////////////
// Verify Kernel Argument
//
// Verify certain arguments before and after kernel invocation
////////////////////////////////////////////////////////////////////////////////

template <typename T>
decltype(auto) check_kernel_arg(const SourceContext& context, T&& arg) {
if constexpr (is_tensor_accessor_builder_v<std::decay_t<T>>) {
// If the arg is a TensorAccessorBuilder, run verifications on the tensor it
// is ref-wrapping, e.g. NaN value checks.
return arg.checkValues(context.description());
} else {
// Otherwise, perfect-forward the argument as is
return std::forward<T>(arg);
}
}
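
For intuition, a minimal standalone sketch of the pattern used above, with hypothetical FakeBuilder, check_arg, and launch_stub names standing in for TensorAccessorBuilder, check_kernel_arg, and the launcher: a C++17 comma fold applies the check to every argument of the pack, and only builder-typed arguments are actually inspected.

#include <iostream>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for TensorAccessorBuilder; only this type is checked.
struct FakeBuilder {
  const char* name;
  void checkValues(const char* ctx) const {
    std::cout << ctx << ": verifying tensor '" << name << "'\n";
  }
};

template <typename T>
decltype(auto) check_arg(const char* ctx, T&& arg) {
  if constexpr (std::is_same_v<std::decay_t<T>, FakeBuilder>) {
    return arg.checkValues(ctx);  // run the verification (returns void)
  } else {
    return std::forward<T>(arg);  // every other argument passes through untouched
  }
}

template <typename... Args>
void launch_stub(Args&&... args) {
  // Comma fold: expands to one check_arg(...) call per argument, left to right.
  (check_arg("pre-execution", std::forward<Args>(args)), ...);
}

int main() {
  launch_stub(FakeBuilder{"dev_weights"}, 42, 3.14);  // only the builder is verified
}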

////////////////////////////////////////////////////////////////////////////////
// GPU Kernel Launcher
//
// This class encapsulates the common ceremonial pre- and post-execution
// routines when launching GPU kernels.
////////////////////////////////////////////////////////////////////////////////

template <bool EnableDSA = false, bool EnableBarrierIsolation = false>
template <
bool EnableDSA = false,
bool EnableBarrierIsolation = false,
bool EnableNaNChecks = false>
struct KernelLauncher {
const SourceContext context;

@@ -234,6 +255,21 @@
// device associated with the compute stream
checkSharedMemoryPerBlockNotExceeded(properties, shared_mem_per_block);

// If NaN checks are enabled, run verifications on all kernel arguments that
// are tensors
if constexpr (EnableNaNChecks) {
const auto summary = std::string(context.summary) + " (pre-execution)";
(check_kernel_arg(context.withSummary(summary), std::forward<Args>(args)),
...);
}

// If barrier isolation is enabled, synchronize the stream first before
// launching the kernel. This has roughly the same effect as setting
// `CUDA_LAUNCH_BLOCKING=1` as an environment variable.
if constexpr (EnableBarrierIsolation) {
cudaDeviceSynchronize();
}

if constexpr (EnableDSA) {
// This launch code here is essentially the same as the contents of
// TORCH_USE_CUDA_DSA macro, but with the addition of kernel argument
@@ -251,13 +287,6 @@
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();
#endif

// If barrier isolation is enabled, synchronize the stream first before
// launching the kernel. This has roughly the same effect as setting
// `CUDA_LAUNCH_BLOCKING=1` as an environment variable.
if constexpr (EnableBarrierIsolation) {
cudaDeviceSynchronize();
}

// Launch the kernel
kernel<<<grid, block, shared_mem_per_block, stream>>>(
// Transform arguments to the kernel before forwarding them.
@@ -285,6 +314,14 @@ struct KernelLauncher {

// Check for CUDA errors
C10_CUDA_KERNEL_LAUNCH_CHECK();

// If NaN checks are enabled, run post-kernel verifications on all kernel
// arguments that are tensors
if constexpr (EnableNaNChecks) {
const auto summary = std::string(context.summary) + " (post-execution)";
(check_kernel_arg(context.withSummary(summary), std::forward<Args>(args)),
...);
}
}
};

@@ -320,30 +357,38 @@ struct KernelLauncher {
#define _FKL_TFILE_ ""
#endif

#ifdef FBGEMM_GPU_KERNEL_DEBUG
#define _FKL_KDEBUG_ true
#else
#define _FKL_KDEBUG_ false
#endif

#ifdef FBGEMM_GPU_ISOLATE_KERNEL_LAUNCH
#define _FKL_BLOCKING_ true
#else
#define _FKL_BLOCKING_ false
#endif

#ifdef FBGEMM_GPU_TENSORCHECK
#define _FKL_TENSORCHECK_ true
#else
#define _FKL_TENSORCHECK_ false
#endif

#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
([&] { \
using source_location = fbgemm_gpu::utils::source_location; \
constexpr auto location = source_location::current(); \
decltype(KERNEL)& kernel = KERNEL; \
\
return fbgemm_gpu::utils::KernelLauncher<false, _FKL_KDEBUG_>( \
location, #KERNEL, _FKL_TFILE_) \
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
#define FBGEMM_LAUNCH_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
([&] { \
using source_location = fbgemm_gpu::utils::source_location; \
constexpr auto location = source_location::current(); \
decltype(KERNEL)& kernel = KERNEL; \
\
return fbgemm_gpu::utils:: \
KernelLauncher<false, _FKL_BLOCKING_, _FKL_TENSORCHECK_>( \
location, #KERNEL, _FKL_TFILE_) \
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
}())

#define FBGEMM_LAUNCH_DSA_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
([&] { \
using source_location = fbgemm_gpu::utils::source_location; \
constexpr auto location = source_location::current(); \
decltype(KERNEL)& kernel = KERNEL; \
\
return fbgemm_gpu::utils::KernelLauncher<true, _FKL_KDEBUG_>( \
location, #KERNEL, _FKL_TFILE_) \
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
#define FBGEMM_LAUNCH_DSA_KERNEL(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
([&] { \
using source_location = fbgemm_gpu::utils::source_location; \
constexpr auto location = source_location::current(); \
decltype(KERNEL)& kernel = KERNEL; \
\
return fbgemm_gpu::utils:: \
KernelLauncher<true, _FKL_BLOCKING_, _FKL_TENSORCHECK_>( \
location, #KERNEL, _FKL_TFILE_) \
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
}())
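
As a usage illustration, here is a hypothetical launch site written against the macro above (the scale_kernel and scale names, launch geometry, and arguments are invented for this sketch, not part of the change): plain pointer and scalar arguments are perfect-forwarded to the kernel, while any TensorAccessorBuilder arguments built with PTA_B/PTA_ACC_B would additionally be NaN/Inf-checked before and after the launch when FBGEMM_GPU_TENSORCHECK is defined.

#include <ATen/cuda/CUDAContext.h>
#include "fbgemm_gpu/utils/kernel_launcher.cuh"

__global__ void scale_kernel(float* x, float alpha, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    x[i] *= alpha;
  }
}

void scale(float* x, float alpha, int n) {
  constexpr int kThreads = 256;
  FBGEMM_LAUNCH_KERNEL(
      scale_kernel,
      dim3((n + kThreads - 1) / kThreads),  // grid
      dim3(kThreads),                       // block
      0,                                    // no dynamic shared memory
      at::cuda::getCurrentCUDAStream(),
      x,
      alpha,
      n);
}
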
5 changes: 5 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/source_context.h
@@ -80,6 +80,11 @@ struct SourceContext {

return *desc_;
}

inline SourceContext withSummary(
const std::string_view& sum_) const noexcept {
return SourceContext(location, sum_, secondaryLocation);
}
};

} // namespace fbgemm_gpu::utils
20 changes: 20 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor_builder.h
@@ -220,6 +220,26 @@ struct TensorAccessorBuilder {
return build_ta(context);
}
}

//////////////////////////////////////////////////////////////////////////////
// Check Tensor values for NaN
//////////////////////////////////////////////////////////////////////////////

C10_ALWAYS_INLINE void checkValues(const std::string_view& context) const {
TORCH_CHECK(
!at::isnan(tensor).any().item<bool>(),
context,
": Tensor '",
name,
"' contains NaN values!");

TORCH_CHECK(
!at::isinf(tensor).any().item<bool>(),
context,
": Tensor '",
name,
"' contains (+/-) Inf values!");
}
};

} // namespace fbgemm_gpu::utils
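
For reference, a small host-side sketch of the predicates checkValues() is built on (the tensors here are made up for illustration): at::isnan / at::isinf produce elementwise masks, any() reduces each mask to a single boolean, and item<bool>() copies that boolean back to the host, which synchronizes with the device.

#include <ATen/ATen.h>
#include <iostream>
#include <limits>

int main() {
  // A tensor deliberately poisoned with NaNs, and a clean one.
  const auto bad = at::full({4}, std::numeric_limits<float>::quiet_NaN());
  const auto ok = at::ones({4});

  // The same predicates checkValues() uses.
  std::cout << at::isnan(bad).any().item<bool>() << '\n';  // 1
  std::cout << at::isinf(bad).any().item<bool>() << '\n';  // 0 (NaN is not Inf)
  std::cout << at::isnan(ok).any().item<bool>() << '\n';   // 0

  // TORCH_CHECK raises c10::Error when the predicate fails, as in checkValues().
  TORCH_CHECK(
      !at::isnan(ok).any().item<bool>(),
      "Tensor 'ok' contains NaN values!");
}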