[Misc] Use torch.compile for basic custom ops #7110
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI, as it is required to merge (or just use auto-merge).
Hi @WoosukKwon, this PR looks good! I have verified it on the CPU backend and it worked well with and without multiprocessing.

To make this PR work on the CPU backend, we should add `triton >= 3.0.0` to `requirements-cpu.txt` to avoid an import error. It looks like a bug in torch 2.4.
vllm/model_executor/custom_op.py (Outdated)

```python
        """
        if not self._is_compiled and not envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static,
```
Setting `dynamic=True` explicitly can reduce recompilations caused by the dynamic batch size. Maybe CUDA behaves similarly.
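A minimal sketch of the idea (the op and shapes below are illustrative, not the vLLM code): with `dynamic=True`, Dynamo treats the batch dimension as symbolic, so varying token counts reuse the same compiled graph instead of triggering recompilation.

```python
import torch
import torch.nn.functional as F

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and apply SiLU-gated multiplication.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

# dynamic=True marks input shapes as symbolic, so changing the batch size
# (e.g., 8 -> 16 tokens) reuses the same compiled graph instead of recompiling.
compiled = torch.compile(silu_and_mul, dynamic=True)

out_small = compiled(torch.randn(8, 256))
out_large = compiled(torch.randn(16, 256))  # no recompilation expected
```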
vllm/model_executor/custom_op.py (Outdated)

```python
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static,
                options={
                    "fx_graph_cache": True,
```
"fx_graph_cache": True, |
This option causes lock contention when using multiprocessing.
Maybe you can set a per-process fx_graph_cache by setting the env var `TORCHINDUCTOR_CACHE_DIR`. See: https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html
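A hedged sketch of that suggestion (the helper name and cache location are assumptions, not vLLM code): each worker process points Inductor at its own cache directory before the first `torch.compile` call, so concurrent workers do not contend on one on-disk FX graph cache.

```python
import os
import tempfile

import torch

def _set_per_process_inductor_cache() -> None:
    # Give each process its own Inductor cache directory, keyed by PID,
    # so multiprocessing workers do not fight over the same cache lock.
    cache_dir = os.path.join(tempfile.gettempdir(), f"inductor_cache_{os.getpid()}")
    os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir

_set_per_process_inductor_cache()

# Any subsequent torch.compile call in this process writes to the per-process cache.
compiled_relu = torch.compile(torch.nn.functional.relu)
print(compiled_relu(torch.randn(4)))
```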
Removed. I think we can explore caching in a future PR.
```python
    @staticmethod
    def forward_static(x: torch.Tensor) -> torch.Tensor:
```
Does this completely eliminate the custom silu_and_mul kernel? If so, should it be removed from csrc?
Ditto for the rest of the custom activation ops.
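For context, the native PyTorch path that torch.compile would fuse looks roughly like this (a sketch, not a verbatim copy of vLLM's SiluAndMul):

```python
import torch
import torch.nn.functional as F

class SiluAndMul(torch.nn.Module):
    """Gated SiLU activation: split the last dim in half, apply SiLU to the
    first half, and multiply it with the second half."""

    @staticmethod
    def forward_static(x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.compile can fuse the slice, SiLU, and multiply into one kernel,
        # which is why the hand-written CUDA kernel may become redundant.
        return self.forward_static(x)
```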
That's a good question. I think we can delete most of them, while leaving some (e.g., in `csrc/legacy`) for potential future use?
```python
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # forward_native() is too complex to be optimized by torch.compile.
        # Fall back to the custom C++ kernel.
        return self.forward_cuda(positions, query, key, offsets)
```
Isn't it weird that forward_cpu calls into forward_cuda? 🤔
Maybe rename it, if it dispatches to the C++ CPU kernel as well?
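To illustrate the naming concern (a simplified sketch; the real dispatch in vLLM has more backends, and `forward_cpp_kernel` is a hypothetical name, not an actual method): `forward_cpu` falling through to `forward_cuda` works when the underlying C++ op is registered for both devices, but a more neutral name would read better.

```python
import torch

class CustomOp(torch.nn.Module):
    """Illustrative base class: dispatch to a per-backend implementation."""

    def forward(self, *args, **kwargs):
        if torch.cuda.is_available():
            return self.forward_cuda(*args, **kwargs)
        return self.forward_cpu(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cpu(self, *args, **kwargs):
        # Assume the bound C++ op handles both CPU and CUDA tensors, so reuse
        # the "CUDA" path. A name like forward_cpp_kernel would make that
        # intent clearer, which is the renaming suggestion above.
        return self.forward_cuda(*args, **kwargs)
```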
@WoosukKwon can we expect this PR to be merged soon? Is there anything we can do to help it get merged?
This PR introduces torch.compile for the following basic custom ops: activations and RMSNorm. The main goals are:
- Optimize ops such as GemmaRMSNorm with torch.compile. This leads to a 20% throughput improvement for Gemma2-27B on 1xH100.
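Sketch of the overall pattern, under the assumption that it mirrors the diff above (simplified, not the exact merged code): a shape-agnostic `forward_static` is compiled lazily on first call.

```python
import torch

class CompiledCustomOp(torch.nn.Module):
    """Wrap a shape-agnostic static forward in torch.compile on first use."""

    def __init__(self) -> None:
        super().__init__()
        self._is_compiled = False

    @staticmethod
    def forward_static(x: torch.Tensor) -> torch.Tensor:
        # Example op: RMSNorm-style normalization (weights omitted for brevity).
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self._is_compiled:
            # Compile lazily so importing the module stays cheap; the compiled
            # function shadows the staticmethod on this instance.
            self.forward_static = torch.compile(self.forward_static)  # type: ignore
            self._is_compiled = True
        return self.forward_static(x)
```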