Skip to content

[Torch] Support mask mod arguments in flex_attention HOP#4570

Open
keshavvinayak01 wants to merge 2 commits into
llvm:mainfrom
iree-org:flex-mask-mod-args
Open

[Torch] Support mask mod arguments in flex_attention HOP#4570
keshavvinayak01 wants to merge 2 commits into
llvm:mainfrom
iree-org:flex-mask-mod-args

Conversation

@keshavvinayak01

@keshavvinayak01 keshavvinayak01 commented May 13, 2026

Copy link
Copy Markdown
Contributor

In traced/exported FX graphs, Dynamo represents mask_mod closure captures for the flex_attention HOP as the final mask_mod_other_buffers argument. torch.hop_flex_attention only modelled the mask function symbol, so those captures could not be represented in Torch IR.

This adds trailing variadic mask_mod_other_buffers operands to torch.hop_flex_attention and prints them with explicit mask_mod_other_buffers(...) syntax. The FX importer forwards only exported mask_mod_other_buffers values into those operands. score_mod_other_buffers are named and rejected with an explicit unsupported error instead of being silently dropped.

@keshavvinayak01 keshavvinayak01 marked this pull request as ready for review May 13, 2026 15:26
@keshavvinayak01 keshavvinayak01 force-pushed the flex-mask-mod-args branch 2 times, most recently from 5ff3fba to 819aea0 Compare May 13, 2026 15:35
Comment thread python/torch_mlir/extras/fx_importer.py Outdated
Comment thread python/torch_mlir/extras/fx_importer.py Outdated
Comment thread include/torch-mlir/Dialect/Torch/IR/TorchOps.td Outdated
Comment thread test/Dialect/Torch/ops.mlir Outdated
@rkayaith rkayaith requested a review from zjgarvey May 14, 2026 16:22

@zjgarvey zjgarvey left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm mostly concerned about the silently dropped arg.

Another thing I'd like to ask is if you would be willing to add verifier logic for at least checking operand count vs. mask_mod_fn arity (or at least something meaningful to cover the new case).

Comment thread python/torch_mlir/extras/fx_importer.py
Comment thread python/torch_mlir/extras/fx_importer.py Outdated
Comment thread include/torch-mlir/Dialect/Torch/IR/TorchOps.td
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>

@keshavvinayak01 keshavvinayak01 left a comment

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reviews, please check again.

@keshavvinayak01 keshavvinayak01 requested review from rkayaith and zjgarvey and removed request for rkayaith May 18, 2026 12:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants