Commit af66ca9

superbobry authored and Google-ML-Automation committed
[pallas:mosaic_gpu] Do not do unnecessary commit_smem_to_gmem_group in emit_pipeline
If the loop does no copies, there is nothing to commit.

PiperOrigin-RevId: 758660634
1 parent 7f5b6e7 commit af66ca9

1 file changed: +3 -1 lines changed


jax/_src/pallas/mosaic_gpu/pipeline.py

Lines changed: 3 additions & 1 deletion
@@ -326,7 +326,8 @@ def loop_body(step, carry):
             predicate=lax.bitwise_or(slices_changed, is_last_step),
         )
 
-      gpu_primitives.commit_smem_to_gmem_group()
+      if copies_out_in_loop:
+        gpu_primitives.commit_smem_to_gmem_group()
 
       fetch_step = step + (max_concurrent_steps - delay_release)
       fetch_slot = lax.rem(fetch_step, max_concurrent_steps)
@@ -367,6 +368,7 @@ def do_fetch():
     # loop. This is the only place where we store them.
     if not copies_out_in_loop:
       gpu_primitives.commit_smem()
+
     last_slot = lax.rem(num_steps - 1, max_concurrent_steps)
     for bref in out_brefs:
       if bref.is_index_invariant:
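
The effect of the new guard can be illustrated with a toy model. This is only a sketch in plain Python, not the Pallas API: `trace_loop_body`, its `guard_commit` parameter, and the op names in the returned list are hypothetical stand-ins. It assumes, as the commit message suggests, that SMEM-to-GMEM copies are issued inside the loop only when `copies_out_in_loop` is true, and that the flag is an ordinary Python bool resolved at trace time.

```python
def trace_loop_body(copies_out_in_loop, num_outputs, guard_commit=True):
    """Return the ops one traced loop step would emit (toy model, hypothetical names)."""
    ops = ["body"]
    if copies_out_in_loop:
        # Assumption: outputs are written back inside the loop only in this case.
        ops += [f"copy_out[{i}]" for i in range(num_outputs)]
    # guard_commit=False models the old behavior: the commit was unconditional.
    if copies_out_in_loop or not guard_commit:
        ops.append("commit_smem_to_gmem_group")
    ops.append("fetch")
    return ops


# Loop without copies: the old code still emitted a commit, the new code does not.
before = trace_loop_body(copies_out_in_loop=False, num_outputs=2, guard_commit=False)
after = trace_loop_body(copies_out_in_loop=False, num_outputs=2)
print(before)
print(after)
```

Because the condition is a Python `if` rather than a traced predicate, the commit is simply never emitted into the loop body when no copies are issued there; when copies are issued, the emitted ops are unchanged.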
