Closed
Description
I ran the `gqa_bwd` example in TileLang with `ref_program` replaced by the FlashAttention-3 (flash-attn3) function, and found that TileLang takes much longer than the FA3 baseline (see the timings below). I noticed that the example uses `atomic_add` instead of the two-stage strategy used in the Triton version. Is that the reason?
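For context, in GQA backward several query heads share one KV head, so their dK/dV contributions have to be summed. Below is a minimal NumPy sketch of the two reduction styles the question contrasts, assuming query head `h` maps to KV head `h // groups`. The function names and sizes are hypothetical, and this is neither TileLang's nor Triton's actual kernel code; "two-stage" is read here as "store one dK slice per query head, then reduce over the group axis".

```python
import numpy as np


def dkv_atomic_style(dk_per_qhead: np.ndarray, groups: int) -> np.ndarray:
    """Sum each query head's dK into its shared KV head, one head at a time.

    On a GPU, each query head's block would atomically add into the same
    fp32 buffer; this loop only reproduces the arithmetic, serially.
    """
    b, h_q, s, d = dk_per_qhead.shape
    dk = np.zeros((b, h_q // groups, s, d), dtype=np.float32)
    for h in range(h_q):
        # Assumed mapping: query head h shares KV head h // groups.
        dk[:, h // groups] += dk_per_qhead[:, h]
    return dk


def dkv_two_stage(dk_per_qhead: np.ndarray, groups: int) -> np.ndarray:
    """Two-stage variant: stage 1 already stored one dK slice per query head
    (the input array, written with plain stores); stage 2 sums over the group axis."""
    b, h_q, s, d = dk_per_qhead.shape
    return dk_per_qhead.reshape(b, h_q // groups, groups, s, d).sum(axis=2)


rng = np.random.default_rng(0)
# Hypothetical sizes: batch 2, 8 query heads, 2 KV heads (groups=4), seq 16, dim 4.
per_head = rng.standard_normal((2, 8, 16, 4)).astype(np.float32)
a = dkv_atomic_style(per_head, groups=4)
t = dkv_two_stage(per_head, groups=4)
print(a.shape, np.allclose(a, t, atol=1e-5))
```

Both variants produce the same result up to floating-point summation order. On a GPU the trade-off is contention on shared fp32 atomics versus an extra per-query-head buffer plus a second pass, so this sketch alone does not show which one explains the gap.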
```
$ python example_gqa_bwd.py 
2025-09-30 14:25:08  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `flash_fwd` with `out_idx=[3, 4]`
2025-09-30 14:25:21  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `flash_fwd`
2025-09-30 14:25:21  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `flash_bwd_prep` with `out_idx=[2]`
2025-09-30 14:25:29  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `flash_bwd_prep`
2025-09-30 14:25:29  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `flash_bwd_post` with `out_idx=[1]`
2025-09-30 14:25:38  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `flash_bwd_post`
2025-09-30 14:25:38  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `flash_bwd` with `out_idx=None`
2025-09-30 14:25:59  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `flash_bwd`
torch: 1.47 ms
torch: 303.01 TFlops
tilelang: 8.74 ms
tilelang: 51.09 TFlops
```
Rachmanino