Regarding the bwd cost of flash-attn3 in TileLang #917

@imhuim982

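I ran the gqa_bwd example in TileLang with ref_program replaced by the flash-attn3 function, and the TileLang backward pass is about 6x slower than the FA3 baseline (8.74 ms vs. 1.47 ms; the lines labeled torch in the output below come from the replaced ref_program, i.e. FA3). I noticed that the TileLang kernel uses atomic_add instead of the two-stage strategy used in the Triton version. Is that the reason for the gap?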

$ python example_gqa_bwd.py 
2025-09-30 14:25:08  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `flash_fwd` with `out_idx=[3, 4]`
2025-09-30 14:25:21  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `flash_fwd`
2025-09-30 14:25:21  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `flash_bwd_prep` with `out_idx=[2]`
2025-09-30 14:25:29  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `flash_bwd_prep`
2025-09-30 14:25:29  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `flash_bwd_post` with `out_idx=[1]`
2025-09-30 14:25:38  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `flash_bwd_post`
2025-09-30 14:25:38  [TileLang:tilelang.jit.kernel:INFO]: TileLang begins to compile kernel `flash_bwd` with `out_idx=None`
2025-09-30 14:25:59  [TileLang:tilelang.jit.kernel:INFO]: TileLang completes to compile kernel `flash_bwd`
torch: 1.47 ms
torch: 303.01 TFlops
tilelang: 8.74 ms
tilelang: 51.09 TFlops
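
To make the question concrete, here is a minimal NumPy sketch of the two accumulation patterns I have in mind. It is not the TileLang or Triton kernel code: the function names and the toy setup are hypothetical, dS stands in for the already-computed score gradient, and I am assuming "two-stage" means one pass that owns the dK tiles followed by a separate pass that owns the dQ tiles.

```python
import numpy as np


def bwd_single_pass(dS, Q, K, block):
    """One pass over KV blocks. Each dQ tile is written by every KV block,
    so on a GPU those updates would need atomic adds."""
    n_q, n_k = dS.shape
    dQ = np.zeros_like(Q)
    dK = np.zeros_like(K)
    for j in range(0, n_k, block):        # one parallel program per KV block
        for i in range(0, n_q, block):
            dS_ij = dS[i:i + block, j:j + block]
            # dK tile j is owned by this program: plain accumulation.
            dK[j:j + block] += dS_ij.T @ Q[i:i + block]
            # dQ tile i is shared by all KV-block programs: atomic add on a GPU.
            dQ[i:i + block] += dS_ij @ K[j:j + block]
    return dQ, dK


def bwd_two_pass(dS, Q, K, block):
    """Two passes. Every output tile has exactly one owner, so no atomics
    are needed, at the price of visiting each dS tile twice."""
    n_q, n_k = dS.shape
    dQ = np.zeros_like(Q)
    dK = np.zeros_like(K)
    # Pass 1: one program per KV block, reduce over Q blocks into dK.
    for j in range(0, n_k, block):
        for i in range(0, n_q, block):
            dK[j:j + block] += dS[i:i + block, j:j + block].T @ Q[i:i + block]
    # Pass 2: one program per Q block, reduce over KV blocks into dQ.
    for i in range(0, n_q, block):
        for j in range(0, n_k, block):
            dQ[i:i + block] += dS[i:i + block, j:j + block] @ K[j:j + block]
    return dQ, dK


# Toy sizes, chosen only for illustration.
rng = np.random.default_rng(0)
dS = rng.standard_normal((8, 12))
Q = rng.standard_normal((8, 4))
K = rng.standard_normal((12, 4))

dQ_a, dK_a = bwd_single_pass(dS, Q, K, block=4)
dQ_b, dK_b = bwd_two_pass(dS, Q, K, block=4)
print(np.allclose(dQ_a, dQ_b), np.allclose(dK_a, dK_b))
```

Both variants produce the same gradients; they differ only in which program writes each output tile. In the single-pass form the dQ writes collide across KV blocks, which is where I would expect atomic_add contention to cost time on real hardware.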
