From 06b21349bcf6ddf6a1686a47a137ad1446579db9 Mon Sep 17 00:00:00 2001 From: dePaul Miller Date: Thu, 1 Aug 2024 09:20:28 -0700 Subject: [PATCH] 1x1x1 cluster launch (#1673) --- .../cutlass/conv/device/conv_universal_adapter.hpp | 14 +++++++++++--- .../cutlass/gemm/device/gemm_universal_adapter.h | 11 +++++++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/include/cutlass/conv/device/conv_universal_adapter.hpp b/include/cutlass/conv/device/conv_universal_adapter.hpp index 9812937e2e..0472b898c2 100644 --- a/include/cutlass/conv/device/conv_universal_adapter.hpp +++ b/include/cutlass/conv/device/conv_universal_adapter.hpp @@ -296,7 +296,9 @@ class ConvUniversalAdapter Status launch_result; // Use extended launch API only for mainloops that use it - if constexpr(ConvKernel::ArchTag::kMinComputeCapability >= 90) { + if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) { + constexpr bool is_static_1x1x1 = cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape> and + cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1; dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})); @@ -324,8 +326,14 @@ class ConvUniversalAdapter CUTLASS_ASSERT(cuda_adapter == nullptr); void const* kernel = (void const*) device_kernel<ConvKernel>; if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90) { - launch_result = ClusterLauncher::launch( - grid, cluster, block, smem_size, stream, kernel, kernel_params); + if constexpr (is_static_1x1x1) { + device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params); + launch_result = Status::kSuccess; + } + else { + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params); + } } } } diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index ce7fd3203d..40094dcb10 100644 --- 
a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -350,6 +350,8 @@ class GemmUniversalAdapter< Status launch_result{ Status::kSuccess }; // Use extended launch API only for mainloops that use it if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { + constexpr bool is_static_1x1x1 = cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape> and + cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); @@ -383,8 +385,13 @@ class GemmUniversalAdapter< CUTLASS_ASSERT(cuda_adapter == nullptr); void const* kernel = (void const*) device_kernel<GemmKernel>; if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) { - launch_result = ClusterLauncher::launch( - grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + if (is_static_1x1x1 && not launch_with_pdl) { + device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params); + } + else { + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + } } } }