Stream-K with broadcast (NVIDIA#892)
* [WIP] GEMM StreamK w/ Fused Epilogue

* Adds a Gemm StreamK with Fused Epilogue kernel-level struct.
  * Mostly based on Gemm with Fused Epilogue
  * Requires a new epilogue
  * Work in progress

* [WIP] StreamK support for GemmUniversalWithBroadcast

  * Based on how StreamK is enabled in GemmUniversal
  * Untested and a work in progress

* Minor fixes

* [WIP] It compiles!

It is almost certainly incorrect, but we're past getting the templates
to match, so checkpointing.

* Correction to reference kernel

* Fix typo

* Added MSE measurement

* Switch back to reference kernel + host for loop

Still WIP. We now get an even larger MSE, and it shows up on both
basic Split-K and Stream-K.
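
For reference, an MSE check of this kind can be a simple host-side loop over
the CUTLASS and reference D tensors. A minimal sketch, assuming the example's
options.tensor_d / options.tensor_ref_d host tensors have already been copied
back to the host (e.g. via sync_host()); this is illustrative, not the exact
code in the example:

  // Mean squared error between CUTLASS output and the reference kernel output.
  double mse = 0.0;
  for (int m = 0; m < options.problem_size.m(); ++m) {
    for (int n = 0; n < options.problem_size.n(); ++n) {
      double diff = double(options.tensor_d.host_view().at({m, n})) -
                    double(options.tensor_ref_d.host_view().at({m, n}));
      mse += diff * diff;
    }
  }
  mse /= double(options.problem_size.m()) * double(options.problem_size.n());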

* Fix typos

* Fix broadcast vector + requested changes

* Comment typo

* Small int option and more

* Fix incorrect condition on whether a source is needed

* Requested changes

* I think I got it?

* Bias vector should be stride 0

* Two-source support added!

* Typos

* Merge examples

* Bring back vector row offset

Just to ensure consistency with universal GEMM with fused epilogue.

* Base arguments and params structs for StreamK

* StreamK epilogue with broadcast now inherits from the original

* Undo changes to params_streamk_base.h

---------

Co-authored-by: Ali Hassani <ahassanijr@gmail.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
3 people authored May 22, 2023
1 parent 6fbc0d3 commit 13f4134
Showing 9 changed files with 4,285 additions and 3 deletions.
4 changes: 4 additions & 0 deletions examples/47_ampere_gemm_universal_streamk/CMakeLists.txt
@@ -33,3 +33,7 @@ cutlass_example_add_executable(
   ampere_gemm_universal_streamk.cu
 )
 
+cutlass_example_add_executable(
+  47_ampere_gemm_universal_streamk_broadcast
+  ampere_gemm_universal_streamk_broadcast.cu
+)
@@ -495,23 +495,23 @@ int main(int argc, const char **argv)
   options.tensor_d.resize(options.problem_size.mn());     // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
   options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel
 
-  // Fill matrix A on host with uniform-random data [2, -2]
+  // Fill matrix A on host with uniform-random data [-2, 2]
   cutlass::reference::host::TensorFillRandomUniform(
       options.tensor_a.host_view(),
       1,
       ElementA(2),
       ElementA(-2),
       0);
 
-  // Fill matrix B on host with uniform-random data [2, -2]
+  // Fill matrix B on host with uniform-random data [-2, 2]
   cutlass::reference::host::TensorFillRandomUniform(
       options.tensor_b.host_view(),
       1,
       ElementB(2),
       ElementB(-2),
       0);
 
-  // Fill matrix C on host with uniform-random data [2, -2]
+  // Fill matrix C on host with uniform-random data [-2, 2]
   cutlass::reference::host::TensorFillRandomUniform(
       options.tensor_c.host_view(),
       1,
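
The comment fix above reflects the argument order of TensorFillRandomUniform,
which takes the maximum before the minimum, so ElementA(2), ElementA(-2) draws
from [-2, 2], not [2, -2]. A minimal annotated sketch of the call shape (the
annotations are ours, not from the example):

  // TensorFillRandomUniform(view, seed, max, min, bits)
  cutlass::reference::host::TensorFillRandomUniform(
      options.tensor_a.host_view(),  // destination tensor view
      1,                             // RNG seed
      ElementA(2),                   // maximum value
      ElementA(-2),                  // minimum value
      0);                            // bits; a non-negative value quantizes,
                                     // and 0 rounds to whole numbers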

Large diffs are not rendered by default.

@@ -48,6 +48,7 @@
 #include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
 #include "cutlass/epilogue/threadblock/epilogue.h"
 #include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h"
+#include "cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h"
 
 #include "cutlass/layout/permute.h"

@@ -120,6 +121,67 @@ struct DefaultEpilogueWithBroadcastTensorOp {
 
 ////////////////////////////////////////////////////////////////////////////////
 
+/// Defines sensible defaults for streamk epilogues for TensorOps.
+template <
+  typename Shape,
+  typename WarpMmaTensorOp,
+  int PartitionsK,
+  typename ElementOutput,
+  typename ElementTensor,
+  typename ElementVector,
+  typename OutputOp,
+  int ElementsPerAccess,
+  bool ScatterD = false,
+  typename PermuteDLayout = layout::NoPermute
+>
+struct DefaultStreamkEpilogueWithBroadcastTensorOp {
+
+  /// Use defaults related to the existing epilogue
+  using Base = DefaultEpilogueTensorOp<
+    Shape,
+    WarpMmaTensorOp,
+    PartitionsK,
+    OutputOp,
+    ElementsPerAccess
+  >;
+
+  //
+  // Stores the result z = (y = GEMM(A, B, C), broadcast)
+  //
+  using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
+    typename Base::OutputTileThreadMap,
+    ElementOutput,
+    ScatterD,
+    PermuteDLayout
+  >;
+
+  //
+  // Additional tensor tile iterator - stores t = Elementwise(z)
+  //
+  using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
+    typename Base::OutputTileThreadMap,
+    ElementTensor
+  >;
+
+  /// Define the epilogue
+  using Epilogue = EpilogueStreamkWithBroadcast<
+    Shape,
+    WarpMmaTensorOp,
+    PartitionsK,
+    OutputTileIterator,
+    TensorTileIterator,
+    ElementVector,
+    typename Base::AccumulatorFragmentIterator,
+    typename Base::WarpTileIterator,
+    typename Base::SharedLoadIterator,
+    OutputOp,
+    typename Base::Padding,
+    Base::kFragmentsPerIteration
+  >;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
 /// Defines sensible defaults for epilogues for VoltaTensorOps.
 template <
   typename Shape,
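
For intuition, the two iterators in the struct above correspond to the two
outputs of the broadcast epilogue: OutputTileIterator writes z and
TensorTileIterator writes t = Elementwise(z). Below is an illustrative
host-side model of those semantics, not the kernel's actual API; the
per-column broadcast direction and the ReLU elementwise op are assumptions for
the sketch, and the stride-0 read of the vector V follows the commit notes:

  // Host-side model of the broadcast epilogue outputs Z and T.
  // V is the bias/broadcast vector; with stride 0, every row of the
  // output tile reads the same N-length vector.
  void broadcast_epilogue_model(
      int M, int N, float alpha, float beta,
      float const *accum,   // [M x N] GEMM accumulators
      float const *C,       // [M x N] source matrix
      float const *V,       // [N] broadcast vector (assumed per-column)
      float *Z, float *T) { // [M x N] outputs
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        float y = alpha * accum[m * N + n] + beta * C[m * N + n]; // y = GEMM(A, B, C)
        float z = y + V[n];                                       // z = y + broadcast
        Z[m * N + n] = z;
        T[m * N + n] = z > 0.f ? z : 0.f;                         // t = Elementwise(z)
      }
    }
  }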
