We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
commit_smem_to_gmem_group
emit_pipeline
1 parent 7f5b6e7, commit af66ca9 (Copy full SHA for af66ca9)
jax/_src/pallas/mosaic_gpu/pipeline.py
@@ -326,7 +326,8 @@ def loop_body(step, carry):
326
predicate=lax.bitwise_or(slices_changed, is_last_step),
327
)
328
329
- gpu_primitives.commit_smem_to_gmem_group()
+ if copies_out_in_loop:
330
+ gpu_primitives.commit_smem_to_gmem_group()
331
332
fetch_step = step + (max_concurrent_steps - delay_release)
333
fetch_slot = lax.rem(fetch_step, max_concurrent_steps)
@@ -367,6 +368,7 @@ def do_fetch():
367
368
# loop. This is the only place where we store them.
369
if not copies_out_in_loop:
370
gpu_primitives.commit_smem()
371
+
372
last_slot = lax.rem(num_steps - 1, max_concurrent_steps)
373
for bref in out_brefs:
374
if bref.is_index_invariant:
0 commit comments