Skip to content

Commit 16b61eb

Browse files
xw285cornellfacebook-github-bot
authored andcommitted
Force determinism by unswizzle (#3727)
Summary: X-link: facebookresearch/FBGEMM#810 Thanks to yifuwang's suggestion, unswizzle to get the same addition order across ranks Reviewed By: yifuwang Differential Revision: D69696369
1 parent de35b3c commit 16b61eb

File tree

1 file changed

+2
-2
lines changed
  • fbgemm_gpu/experimental/gen_ai/src/comm

1 file changed

+2
-2
lines changed

fbgemm_gpu/experimental/gen_ai/src/comm/car.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ __launch_bounds__(512)
146146
bf16x8 vals[kWorldSize];
147147
#pragma unroll kWorldSize
148148
for (int ii = 0; ii < kWorldSize; ++ii) {
149-
*reinterpret_cast<uint4*>(&vals[ii]) =
149+
*reinterpret_cast<uint4*>(&vals[(ii + kWorldSize - rank) % kWorldSize]) =
150150
reinterpret_cast<const uint4*>(&src_d[ii][i])[0];
151151
}
152152

@@ -510,7 +510,7 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
510510
bf16x8 vals[kWorldSize];
511511
#pragma unroll kWorldSize
512512
for (int ii = 0; ii < kWorldSize; ++ii) {
513-
*reinterpret_cast<uint4*>(&vals[ii]) =
513+
*reinterpret_cast<uint4*>(&vals[(ii + kWorldSize - rank) % kWorldSize]) =
514514
reinterpret_cast<const uint4*>(&src_d[ii][i + N_start])[0];
515515
}
516516

0 commit comments

Comments
 (0)