Skip to content

Commit

Permalink
1x1x1 cluster launch (NVIDIA#1673)
Browse files Browse the repository at this point in the history
  • Loading branch information
depaulmillz authored Aug 1, 2024
1 parent eee0cab commit 06b2134
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
14 changes: 11 additions & 3 deletions include/cutlass/conv/device/conv_universal_adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,9 @@ class ConvUniversalAdapter

Status launch_result;
// Use extended launch API only for mainloops that use it
if constexpr(ConvKernel::ArchTag::kMinComputeCapability >= 90) {
if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) {
constexpr bool is_static_1x1x1 = cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape> and
cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1;
dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}),
cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{}));
Expand Down Expand Up @@ -324,8 +326,14 @@ class ConvUniversalAdapter
CUTLASS_ASSERT(cuda_adapter == nullptr);
void const* kernel = (void const*) device_kernel<ConvKernel>;
if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90) {
launch_result = ClusterLauncher::launch(
grid, cluster, block, smem_size, stream, kernel, kernel_params);
if constexpr (is_static_1x1x1) {
device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params);
launch_result = Status::kSuccess;
}
else {
launch_result = ClusterLauncher::launch(
grid, cluster, block, smem_size, stream, kernel, kernel_params);
}
}
}
}
Expand Down
11 changes: 9 additions & 2 deletions include/cutlass/gemm/device/gemm_universal_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,8 @@ class GemmUniversalAdapter<
Status launch_result{ Status::kSuccess };
// Use extended launch API only for mainloops that use it
if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) {
constexpr bool is_static_1x1x1 = cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape> and
cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1;
dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{}));
Expand Down Expand Up @@ -383,8 +385,13 @@ class GemmUniversalAdapter<
CUTLASS_ASSERT(cuda_adapter == nullptr);
void const* kernel = (void const*) device_kernel<GemmKernel>;
if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) {
launch_result = ClusterLauncher::launch(
grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl);
if (is_static_1x1x1 && not launch_with_pdl) {
device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);
}
else {
launch_result = ClusterLauncher::launch(
grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl);
}
}
}
}
Expand Down

0 comments on commit 06b2134

Please sign in to comment.