Skip to content

Commit

Permalink
Add acc2smem in epilogue/threadblock/epilogue.h (NVIDIA#806)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkosaian authored Feb 7, 2023
1 parent 5921043 commit 5ff5209
Showing 1 changed file with 43 additions and 19 deletions.
62 changes: 43 additions & 19 deletions include/cutlass/epilogue/threadblock/epilogue.h
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,47 @@ class Epilogue :
operator()(output_op, destination_iterator, accumulators, SourceAspectNeeded(source_iterator));
}

/// Maps a run-time position onto a compile-time iterator advance.
///
/// Only the partial specialization for cutlass::index_sequence is defined; the primary
/// template is declared so that the specialization can deduce the index pack.
template<class Seq>
struct acc2smem;

/// Specialization over cutlass::index_sequence<Seq...>. Every value in Seq is a candidate
/// position that push() can dispatch to.
template <size_t... Seq>
struct acc2smem<cutlass::index_sequence<Seq...>> {

  /// Steps a copy of the accumulator fragment iterator forward `Advance` times, then
  /// loads Base::kFragmentsPerIteration consecutive fragments from it and stores each one
  /// through warp_tile_iterator (presumably into shared memory, going by the name - confirm).
  ///
  /// \tparam Advance             number of increments applied before the first load
  /// \param accum_fragment_iterator  taken by value, so the caller's iterator is not advanced
  /// \param warp_tile_iterator       taken by reference; the add_pointer_offset() calls made
  ///                                 here cancel out, leaving a net offset of zero
  template<int Advance>
  CUTLASS_DEVICE
  static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
                     WarpTileIterator &warp_tile_iterator) {

    // Skip ahead to the first fragment this call is responsible for.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Advance; i++) {
      ++accum_fragment_iterator;
    }

    CUTLASS_PRAGMA_UNROLL
    for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
      typename AccumulatorFragmentIterator::Fragment accum_fragment;

      accum_fragment_iterator.load(accum_fragment);
      ++accum_fragment_iterator;

      warp_tile_iterator.store(accum_fragment);

      // Move to the next slot for every fragment except the last one.
      if (p < Base::kFragmentsPerIteration - 1) {
        warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
      }
    }

    // Undo the (kFragmentsPerIteration - 1) forward steps taken in the loop above.
    if (Base::kFragmentsPerIteration > 1) {
      warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
                                            (1 - Base::kFragmentsPerIteration));
    }
  }

  /// Calls helper<S>() for the element S of Seq that equals `pos`; does nothing when `pos`
  /// matches no element.
  ///
  /// \param pos                 run-time position to dispatch on
  /// \param iterator_begin      accumulator fragment iterator that helper() copies and advances
  /// \param warp_tile_iterator  destination iterator forwarded to helper()
  ///
  /// NOTE(review): helper<S> skips S single fragments and then consumes
  /// Base::kFragmentsPerIteration of them. When kFragmentsPerIteration > 1, confirm that the
  /// caller's `pos` advances in units of fragments rather than groups.
  CUTLASS_DEVICE
  static void push(size_t pos,
                   AccumulatorFragmentIterator const &iterator_begin,
                   WarpTileIterator &warp_tile_iterator) {
    // Pack expansion inside a braced initializer evaluates one term per element of Seq, in
    // order; the short-circuiting && ensures helper<Seq>() runs only for the matching term.
    // The leading 0 keeps the array non-empty even if Seq is an empty pack.
    int dummy[] = {
      0,
      ((pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0))...
    };

    // The array exists only for its initializer's side effects; silence unused warnings.
    (void)dummy;
  }
};


/// Streams the result to global memory
template <typename SourceAspect>
Expand Down Expand Up @@ -452,25 +493,8 @@ class Epilogue :

__syncthreads();

CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p)
{
typename AccumulatorFragmentIterator::Fragment accum_fragment;

accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;

this->warp_tile_iterator_.store(accum_fragment);

if (p < Base::kFragmentsPerIteration - 1) {
this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset);
}
}

if (Base::kFragmentsPerIteration > 1) {
this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
}

acc2smem<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::push(
iter, accum_fragment_iterator, this->warp_tile_iterator_);

//
// Load fragments from shared memory
Expand Down

0 comments on commit 5ff5209

Please sign in to comment.