Motivation
In PyTorch 2.6, auto_functionalized_v2 was introduced as a replacement for the auto_functionalized higher-order op, partly to address the issues with redundant tensor copies in vLLM. However, certain custom fusion passes rely on pattern matching and don't currently work with auto_functionalized_v2.
Due to this, as well as a separate issue with V2 (PyTorch#147924), we are currently disabling V2 in PyTorch 2.6+. We have also circumvented the copy issues using a FixFunctionalizationPass, reducing the urgency of enabling V2.
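For context, a minimal sketch of how the opt-out currently looks, assuming the Inductor flag keeps its present name (enable_auto_functionalized_v2):

```python
# Minimal sketch: force the V1 auto_functionalized HOP during compilation.
# Assumes the flag is named enable_auto_functionalized_v2 (PyTorch 2.6+);
# older releases simply don't have the attribute.
import torch._inductor.config as inductor_config

if hasattr(inductor_config, "enable_auto_functionalized_v2"):
    inductor_config.enable_auto_functionalized_v2 = False
```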
I am creating this RFC to centralize the discussion about when to upgrade to V2 and how to adapt our custom fusion passes to it.
Motivation for custom passes
Our graph-level optimization system performs graph transformations that would break abstractions or otherwise be intrusive if done in model code. For example, RMSNormFusionPass performs manual fusion of RMSNorm and quantization custom ops. Whether quantization is enabled or not, a simplified model definition looks like this:
```
x2 = RMSNorm(x1)
x3 = Linear(x2)
```
If quantization is on, Linear consists of a quant followed by a mm operation. With RMSNormFusionPass, we fuse the quant onto the rms_norm during Inductor compilation, resulting in performance gains without breaking the Linear abstraction.
Because rms_norm is in-place, we have to pattern-match the auto_functionalized op in the FX graph.
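For illustration, here is a minimal sketch of such a pattern and its replacement under the V1 higher-order op. The op names and argument lists (torch.ops._C.rms_norm, torch.ops._C.static_scaled_fp8_quant, torch.ops._C.rms_norm_static_fp8_quant) are placeholders standing in for the real in-place custom-op schemas:

```python
# Minimal sketch of a V1 pattern/replacement pair (placeholder op schemas).
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized

def pattern(result, result_rms, input, weight, scale, epsilon):
    # In the functionalized graph the in-place ops only appear wrapped in
    # auto_functionalized, so the wrapper is what we have to match.
    rms = auto_functionalized(torch.ops._C.rms_norm.default,
                              result=result_rms, input=input,
                              weight=weight, epsilon=epsilon)
    quant = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
                                result=result, input=rms[1], scale=scale)
    return quant[1]

def replacement(result, result_rms, input, weight, scale, epsilon):
    # One fused (still in-place) op replaces the two matched ones.
    fused = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default,
                                result=result, input=input,
                                weight=weight, scale=scale, epsilon=epsilon)
    return fused[1]
```

These functions would then be registered into a PatternMatcherPass via torch._inductor.pattern_matcher.register_replacement, with example inputs matching the real schemas.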
Why upgrading is not simple
auto_functionalized_v2 contains additional arguments (more info) that track tensors and their views more precisely, so that the functionalized ops can be correctly re-inplaced without redundant copies. (Even more arguments are planned.) This makes pattern matching against it more complex (if not impossible) than against the current version. If we want to enable V2, we need to choose a strategy for pattern matching and replacement.
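To see what a pattern would be up against, one can dump the kwargs of the wrapper nodes in a captured graph. A small sketch, assuming the HOPs are reachable as torch.ops.higher_order.auto_functionalized / auto_functionalized_v2; the exact bookkeeping kwarg names (e.g. _all_bases) come from the upstream implementation and may change between releases:

```python
# Sketch: dump the functionalized wrapper nodes from an FX graph to see the
# extra base/view bookkeeping kwargs that V2 adds.
import torch

def dump_functionalized_nodes(graph: torch.fx.Graph) -> None:
    v1 = torch.ops.higher_order.auto_functionalized
    v2 = getattr(torch.ops.higher_order, "auto_functionalized_v2", None)
    for node in graph.nodes:
        if node.op == "call_function" and node.target in (v1, v2):
            # args[0] is the wrapped op; the kwargs are what a pattern
            # would have to reproduce exactly.
            print(node.target, node.args[0], sorted(node.kwargs))
```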
Proposed Change
After discussing this with @zou3519, I see a few possible approaches to pattern matching in-place ops. Please feel free to suggest any others.
1. Match patterns before functionalization.
We ran into problems with this in the past, but @zou3519 says it should work™. We should certainly give it a try, as it would make our patterns simpler as well. With custom replacement infrastructure, we could also manually edit the graph instead of using the pattern matcher, and the current code is basically already able to do this.
We might have to do noop elimination on a non-functional graph (redundant slices & reshapes) to avoid imposing restrictions on model code (which is the whole point of custom fusion passes). Because the graph is also not normalized or stable at that point, this might be hard.
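For comparison with the sketch above, a pattern written against the non-functionalized graph would match the in-place ops directly (same placeholder op names; purely illustrative):

```python
# Sketch of the same fusion matched *before* functionalization, where the
# mutating ops appear directly in the graph (placeholder op schemas).
import torch

def pattern(result, result_rms, input, weight, scale, epsilon):
    # No auto_functionalized wrapper to match: the mutating calls themselves
    # are the pattern.
    torch.ops._C.rms_norm(result_rms, input, weight, epsilon)
    torch.ops._C.static_scaled_fp8_quant(result, result_rms, scale)
    return result

def replacement(result, result_rms, input, weight, scale, epsilon):
    torch.ops._C.rms_norm_static_fp8_quant(result, input, weight, scale, epsilon)
    return result
```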
2. Match patterns after Inductor re-inplaces functionalized ops.
This would require a new custom hook point from the PyTorch team. Otherwise it has similar pros/cons to option 1, except the noop elimination will have happened already.
3. Introduce functional custom ops.
Instead of wrapping the in-place rms_norm inside auto_functionalized_v2, define a functional custom op (e.g. rms_norm_functional) that returns its output rather than mutating an argument. Pattern matching would then occur on rms_norm_functional + quant instead.
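A minimal sketch of what such an op could look like with torch.library.custom_op; the name rms_norm_functional, its signature, and the reference math are hypothetical, and in practice the op would dispatch to the existing kernel:

```python
# Minimal sketch of a functional RMSNorm custom op (hypothetical name,
# signature, and reference implementation). It returns its output instead of
# mutating an argument, so it never gets wrapped in auto_functionalized and
# can be pattern-matched directly.
import torch

@torch.library.custom_op("vllm::rms_norm_functional", mutates_args=())
def rms_norm_functional(input: torch.Tensor, weight: torch.Tensor,
                        epsilon: float) -> torch.Tensor:
    # Reference math only; a real implementation would call the existing
    # CUDA kernel on a freshly allocated output tensor.
    variance = input.float().pow(2).mean(-1, keepdim=True)
    out = input.float() * torch.rsqrt(variance + epsilon)
    return (out * weight.float()).to(input.dtype)

@rms_norm_functional.register_fake
def _(input, weight, epsilon):
    return torch.empty_like(input)
```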
Right now, it seems like approach 2 is best. If we agree on this, it might be worth asking PyTorch for a custom post-functionalization hook (unless I'm missing something and this already exists).
Feedback Period.
This is not urgent yet; we can set a timeline once we establish this to be a priority.
CC List.
@youkaichao @zou3519 @SageMoore @bnellnm @tms @robertgshaw2-redhat