Skip to content

Commit

Permalink
fix epilogue iterator error (NVIDIA#995)
Browse files (browse the repository at this point in the history)
* fix epilogue iterator error

* fix epilogue iterator error

---------

Co-authored-by: maxiao <maxiao@cowarobot.com>
  • Loading branch information
ChangyouSiom and maxiao authored Jul 11, 2023
1 parent 9b923dd commit e066ced
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions include/cutlass/epilogue/threadblock/predicated_tile_iterator.h
Original file line number Diff line number Diff line change
Expand Up
@@ -440,12 +440,16 @@ class PredicatedTileIterator {
}

if (group + 1 < ThreadMap::Iterations::kGroup) {
- byte_pointer += params_.increment_group;
+ if (!ScatterD && !PermuteD) {
+ byte_pointer += params_.increment_group;
+ }
}
}

if (cluster + 1 < ThreadMap::Iterations::kCluster) {
- byte_pointer += params_.increment_cluster;
+ if (!ScatterD && !PermuteD) {
+ byte_pointer += params_.increment_cluster;
+ }
}
}
}
Expand Down
Expand Up
@@ -650,8 +654,12 @@ class PredicatedTileIterator {

state_[0] = 0;
++state_[1];
- byte_pointer_ += params_.advance_group;
- store_byte_pointer_ += params_.advance_group;
+ if (!ScatterD) {
+ byte_pointer_ += params_.advance_group;
+ }
+ if (!ScatterD && !PermuteD) {
+ store_byte_pointer_ += params_.advance_group;
+ }

thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
Expand All
@@ -660,16 +668,24 @@ class PredicatedTileIterator {

state_[1] = 0;
++state_[2];
- byte_pointer_ += params_.advance_cluster;
- store_byte_pointer_ += params_.advance_cluster;
+ if (!ScatterD) {
+ byte_pointer_ += params_.advance_cluster;
+ }
+ if (!ScatterD && !PermuteD) {
+ store_byte_pointer_ += params_.advance_cluster;
+ }

thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;

if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
- byte_pointer_ += params_.advance_tile;
- store_byte_pointer_ += params_.advance_tile;
+ if (!ScatterD) {
+ byte_pointer_ += params_.advance_tile;
+ }
+ if (!ScatterD && !PermuteD) {
+ store_byte_pointer_ += params_.advance_tile;
+ }

thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow
* ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile;
Expand Down

0 comments on commit e066ced

Please sign in to comment.