
FlexAttention ModIndex misses cache hit for autograd func #151358

Open
@drisspg


Summary

While working on vllm-project/vllm#16078, Richard and I noticed that we miss the cache on repeated runs of "compile_block_mask" because of the mod_index autograd func:

return mod_index(args[0], index_args)

The fix is to check whether grad mode is enabled and whether x requires grad. If so, run the autograd function; otherwise, call the contents of forward directly.
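
A minimal sketch of that guard, assuming mod_index wraps a torch.autograd.Function whose forward does plain advanced indexing; the class name, backward, and index handling below are illustrative, not the actual PyTorch implementation:

```python
import torch

class ModIndex(torch.autograd.Function):
    # Toy stand-in for the real mod_index autograd func; the backward here
    # is illustrative, not the actual flex attention code.
    @staticmethod
    def forward(ctx, x, *index_args):
        ctx.save_for_backward(*index_args)
        ctx.x_shape = x.shape
        return x[index_args]

    @staticmethod
    def backward(ctx, grad_out):
        index_args = ctx.saved_tensors
        grad_x = grad_out.new_zeros(ctx.x_shape)
        # Scatter the incoming gradient back to the indexed positions.
        grad_x.index_put_(index_args, grad_out, accumulate=True)
        return (grad_x,) + (None,) * len(index_args)


def mod_index(x, index_args):
    # Proposed fix: only route through the autograd.Function when autograd
    # could actually record the op; otherwise run the forward's contents
    # directly.
    if torch.is_grad_enabled() and x.requires_grad:
        return ModIndex.apply(x, *index_args)
    return x[tuple(index_args)]
```

With that guard, inference-mode calls should trace as ordinary indexing, so repeated compile_block_mask runs see the same graph instead of a fresh autograd.Function wrapper and can hit the cache.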

cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @yanboliang @BoyuanFeng


Labels

module: flex attention, module: higher order operators, module: pt2-dispatcher, oncall: pt2, triaged
