Skip to content

Commit

Permalink
Skip void-C kernels in the profiler when beta is non zero (NVIDIA#1661)
Browse files Browse the repository at this point in the history
* Skip void-C kernels in the profiler when beta is non zero

The CUTLASS profiler currently only skips disposition for void-C kernels when beta
is non-zero, whereas it makes more sense to skip running them in the first
place.

Not all users are aware of void-C kernels (as far as I know, they weren't a
thing in 2.X), and not everyone remembers to filter out void-C kernels
when running the profiler with a non-zero beta.

The easiest solution (and, as far as I can tell, the correct way of handling this)
is for `can_implement` to return `false` when beta is non-zero (or
whatever argument indicates an epilogue source) but we have a void-C
kernel.

The profiler already includes functionality to skip running kernels that
fail `can_implement`.

* Move checks to collectives instead

---------

Co-authored-by: Ali Hassani <ahassani@nvidia.com>
  • Loading branch information
alihassanijr and Ali Hassani authored Jul 31, 2024
1 parent 8b2a040 commit 1f2b590
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 2 deletions.
21 changes: 21 additions & 0 deletions include/cutlass/epilogue/collective/detail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,27 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp {
tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { }
};

// SFINAE helper for detecting a `thread.beta` member in EVT arguments.
//
// Primary template: selected whenever the partial specialization below is not
// viable for `Args`, and reports `value == false`.
template <typename Args, typename Enable = void>
struct has_beta {
  static constexpr bool value{false};
};

// Partial specialization: participates only when the expression
// `Args{}.thread.beta` is well-formed (substitution failure otherwise removes
// it from consideration), and reports `value == true`.
template <typename Args>
struct has_beta<Args, cute::void_t<decltype(Args{}.thread.beta)>> {
  static constexpr bool value{true};
};

// SFINAE helper for detecting a `thread.beta_ptr` member in EVT arguments.
//
// Primary template: selected whenever the partial specialization below is not
// viable for `Args`, and reports `value == false`.
template <typename Args, typename Enable = void>
struct has_beta_ptr {
  static constexpr bool value{false};
};

// Partial specialization: participates only when the expression
// `Args{}.thread.beta_ptr` is well-formed (substitution failure otherwise
// removes it from consideration), and reports `value == true`.
template <typename Args>
struct has_beta_ptr<Args, cute::void_t<decltype(Args{}.thread.beta_ptr)>> {
  static constexpr bool value{true};
};

} // namespace detail
} // namespace collective
} // namespace epilogue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,23 @@ class CollectiveEpilogue<
if (!fusion_implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
}
return implementable && fusion_implementable;

bool beta_implementable = true;

if constexpr (cute::is_void_v<ElementC>) {
if constexpr (detail::has_beta<Arguments>::value) {
beta_implementable = args.thread.beta == 0.0;
}
if constexpr (detail::has_beta_ptr<Arguments>::value) {
beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr;
}
}

if (!beta_implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n");
}

return implementable && fusion_implementable && beta_implementable;
}

template<class TileShapeMNK>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,23 @@ class CollectiveEpilogue<
if (!fusion_implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
}
return implementable && fusion_implementable;

bool beta_implementable = true;

if constexpr (cute::is_void_v<ElementC>) {
if constexpr (detail::has_beta<Arguments>::value) {
beta_implementable = args.thread.beta == 0.0;
}
if constexpr (detail::has_beta_ptr<Arguments>::value) {
beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr;
}
}

if (!beta_implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n");
}

return implementable && fusion_implementable && beta_implementable;
}

template<class TileShapeMNK>
Expand Down

0 comments on commit 1f2b590

Please sign in to comment.