
Commit

More updates for 3.1 (NVIDIA#958)
* Updates for 3.1

* Minor change

* doc link fix

* Minor updates
ANIKET-SHIVAM authored May 24, 2023
1 parent 13f4134 commit f079619
Showing 48 changed files with 1,574 additions and 1,821 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -3,7 +3,7 @@

## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14)
* New CUTLASS Python interface that aims to provide an easy-to-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python).
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) for FP16 datatype using TMA for Hopper.
* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
* Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such as Bias, ReLU, and GELU, using the new efficient epilogues.
* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
* New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper.
@@ -12,6 +12,11 @@
* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
* Performance optimizations for the [*warp-specialized persistent ping-pong*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel.
* Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
* [FMHA Backward Pass](examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
* [StreamK GEMM with Broadcast](examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
* [Batched B2B GEMM](examples/13_two_tensor_op_fusion) can now run multiple Back-to-Back GEMMs with the same problem size in parallel.
* [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) supports both row-major and column-major input matrices.
* [Permute + GEMM fusion](examples/39_gemm_permute) can now fuse Permute with the GEMM that follows it; previously, only fusing Permute into the GEMM epilogue was supported.
* The GitHub branch is renamed from `master` to `main` in this release.
* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads)
* Updates and bugfixes from the community (thanks!)
8 changes: 7 additions & 1 deletion README.md
@@ -46,14 +46,20 @@ In addition to GEMMs, CUTLASS implements high-performance convolution via the im
CUTLASS 3.1 is an update to CUTLASS adding:

- New CUTLASS Python interface that aims to provide an easy-to-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python).
- New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) for FP16 datatype using TMA for Hopper.
- New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper.
- Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such as Bias, ReLU, and GELU, using the new efficient epilogues.
- New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
- New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that improves performance on Hopper.
- An [example](examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper.
- New Epilogue builders. Similar to mainloop builders (see [example 49](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization.
- Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler.
- Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs.
- [FMHA Backward Pass](examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers.
- [StreamK GEMM with Broadcast](examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM.
- [Batched B2B GEMM](examples/13_two_tensor_op_fusion) can now run multiple Back-to-Back GEMMs with the same problem size in parallel.
- [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) supports both row-major and column-major input matrices.
- [Permute + GEMM fusion](examples/39_gemm_permute) can now fuse Permute with the GEMM that follows it; previously, only fusing Permute into the GEMM epilogue was supported.

- *Announcement*:
- The GitHub branch is renamed from `master` to `main` in this release.
- A slight modification has been made to the ordering of arguments passed to epilogues in 3.x kernels.
@@ -641,6 +641,11 @@ class B2bMmaMultistage :

}

// Commit and drain all pending and predicated cp.async operations from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();

// 2nd Gemm

/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
@@ -871,7 +876,10 @@

}


// Commit and drain all pending and predicated cp.async operations from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();

}
};
@@ -664,6 +664,11 @@ class B2bMmaMultistageSmemAccumulator :

}

// Insert fence and wait for all outstanding cp.async operations to commit.
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();

/// Epilogue for the first Implicit Gemm
Epilogue0 epilogue0;

@@ -855,7 +860,10 @@

}


// Commit and drain all pending and predicated cp.async operations from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();

}
};
11 changes: 4 additions & 7 deletions examples/45_dual_gemm/threadblock/dual_mma_multistage.h
@@ -759,13 +759,10 @@ class DualMmaMultistage :
accum1 = plus_accum(accum1, tmp_accum1);
}

if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated cp.async operations from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}

// commit and drain all pending and predicated cp.async operations from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
};

@@ -461,11 +461,6 @@ Result run(std::string description, Options &options)
std::cout << " GFLOPs: " << result.gflops << std::endl;
}

// TODO: uncomment when results match
//if (!result.passed) {
// exit(-1);
//}

return result;
}

2 changes: 1 addition & 1 deletion include/cute/algorithm/tuple_algorithms.hpp
@@ -499,7 +499,7 @@

namespace detail {

// Shortcut around tuple_cat for common insert/remove/repeat cases
// Shortcut around cute::tuple_cat for common insert/remove/repeat cases
template <class T, class X, int... I, int... J, int... K>
CUTE_HOST_DEVICE constexpr
auto
6 changes: 3 additions & 3 deletions include/cute/atom/mma_atom.hpp
@@ -623,7 +623,7 @@ partition_shape_C(TiledMMA<Args...> const& mma, Shape_MN const& shape_MN)
auto V = shape<1>(typename TiledMMA<Args...>::AtomLayoutC_TV{});
auto M = shape_div(size<0>(shape_MN), size<0>(atomMNK) * size<1>(thrVMNK));
auto N = shape_div(size<1>(shape_MN), size<1>(atomMNK) * size<2>(thrVMNK));
return tuple_cat(make_shape(V,M,N), take<2,R>(shape_MN));
return cute::tuple_cat(make_shape(V,M,N), take<2,R>(shape_MN));
}

template <class... Args, class Shape_MN>
@@ -651,7 +651,7 @@ partition_shape_A(TiledMMA<Args...> const& mma, Shape_MK const& shape_MK)
auto V = shape<1>(typename TiledMMA<Args...>::AtomLayoutA_TV{});
auto M = shape_div(size<0>(shape_MK), size<0>(atomMNK) * size<1>(thrVMNK));
auto K = shape_div(size<1>(shape_MK), size<2>(atomMNK) * size<3>(thrVMNK));
return tuple_cat(make_shape(V,M,K), take<2,R>(shape_MK));
return cute::tuple_cat(make_shape(V,M,K), take<2,R>(shape_MK));
}

template <class... Args, class Shape_NK>
@@ -666,7 +666,7 @@ partition_shape_B(TiledMMA<Args...> const& mma, Shape_NK const& shape_NK)
auto V = shape<1>(typename TiledMMA<Args...>::AtomLayoutB_TV{});
auto N = shape_div(size<0>(shape_NK), size<1>(atomMNK) * size<2>(thrVMNK));
auto K = shape_div(size<1>(shape_NK), size<2>(atomMNK) * size<3>(thrVMNK));
return tuple_cat(make_shape(V,N,K), take<2,R>(shape_NK));
return cute::tuple_cat(make_shape(V,N,K), take<2,R>(shape_NK));
}

//
18 changes: 15 additions & 3 deletions include/cute/container/cuda_types.hpp
@@ -46,8 +46,14 @@ namespace cute

using dim3 = ::dim3;

// MSVC doesn't define its C++ version macro to match
// its C++ language version. This means that when
// building with MSVC, dim3 isn't constexpr-friendly.
template <size_t I>
CUTE_HOST_DEVICE constexpr
CUTE_HOST_DEVICE
#if ! defined(_MSC_VER)
constexpr
#endif
uint32_t& get(dim3& a)
{
static_assert(I < 3, "Index out of range");
@@ -63,7 +69,10 @@
}

template <size_t I>
CUTE_HOST_DEVICE constexpr
CUTE_HOST_DEVICE
#if ! defined(_MSC_VER)
constexpr
#endif
uint32_t const& get(dim3 const& a)
{
static_assert(I < 3, "Index out of range");
@@ -79,7 +88,10 @@
}

template <size_t I>
CUTE_HOST_DEVICE constexpr
CUTE_HOST_DEVICE
#if ! defined(_MSC_VER)
constexpr
#endif
uint32_t&& get(dim3&& a)
{
static_assert(I < 3, "Index out of range");
