[PERF] Allreduce fusion. Support torch native matching. Tuning of the thresholds #24248
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request is a significant enhancement to the all-reduce fusion capabilities, adding support for matching native PyTorch operations in addition to custom ops. This greatly improves usability and performance flexibility. The introduction of a comprehensive benchmark for tuning fusion thresholds is also a valuable addition. The changes are extensive, particularly with the large number of new fusion patterns in vllm/compilation/collective_fusion.py. While the overall approach is sound, I've identified several critical issues in the implementation of these new patterns. Specifically, the return values from some pattern and replacement functions appear to be incorrect, which could lead to fusion failures or incorrect model outputs. I've provided detailed comments and suggestions for these issues. The configuration updates and the new benchmark script are well-implemented and welcome improvements.
The return values from the replacement function are incorrect. The pattern returns (rms_output, allreduce_output), which correspond to the normalized output and the all-reduced tensor. The replacement function should return the same structure.
auto_functionalized(flashinfer_trtllm_fused_allreduce_norm, ...) returns a tuple of 5 mutated arguments: (allreduce_in, residual, norm_out, quant_out, scale_out).
The rms_result corresponds to norm_out, which is allreduce[2].
The allreduce_in (which is input to the replacement function) corresponds to allreduce[0].
Therefore, the return statement should be return allreduce[2], allreduce[0].
The current code returns allreduce[3], allreduce[1], which corresponds to (quant_out, residual). This is incorrect and will lead to fusion failures or wrong results.
Suggested change: replace `return allreduce[3], allreduce[1]` with `return allreduce[2], allreduce[0]`.
The return values from the replacement function are incorrect. The pattern returns (rms_output, rms_residual), which are the normalized output and the residual output. The replacement function should return the same structure.
When norm_out=None is passed to flashinfer_trtllm_fused_allreduce_norm, the allreduce_in tensor is used as the output buffer for the normalization result and is mutated. auto_functionalized will return a tuple where the first element (allreduce[0]) is the mutated allreduce_in (i.e., norm_out), and the second element (allreduce[1]) is the mutated residual.
Therefore, the correct return should be return allreduce[0], allreduce[1].
The current code returns allreduce[1], allreduce[2], which corresponds to (residual, norm_out). Since norm_out is None in the call, this is incorrect.
Suggested change: replace `return allreduce[1], allreduce[2]` with `return allreduce[0], allreduce[1]`.
Just curious: why is the threshold still so low for TP8? I think AR+Norm should have pretty good perf up to some larger message sizes for TP8?
        
          
vllm/config/compilation.py (outdated)
        
why is it 1MB for TP8?
@nvpohanh Here are the results for TP=8 on Blackwell with torch symm mem (VLLM_ALLREDUCE_USE_SYMM_MEM=1) enabled (see the first set of results below). I used the best-performing alternative to the fused allreduce as the comparison point. We could condition the threshold on whether symm mem is available and enabled, but in my opinion that would overcomplicate the configuration. Compared to the default allreduce, the flashinfer fused alternative is not significantly better in the 4-16 MB region (see the results at the end).
Symm mem enabled
World Size: 8
Hidden Dimension: 8192
Warmup Iterations: 5
Benchmark Trials: 20
Quantization Mode: none
Configuration: seq_len=32, dtype=bfloat16, no residual
Input Size: 0.50 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.029 | 1.00x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.030 | 0.99x | 
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.012 | 2.39x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.086 | 0.34x | 
Configuration: seq_len=64, dtype=bfloat16, no residual
Input Size: 1.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.030 | 1.00x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.030 | 0.99x | 
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.018 | 1.62x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.056 | 0.54x | 
Configuration: seq_len=128, dtype=bfloat16, no residual
Input Size: 2.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.023 | 1.00x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.024 | 0.99x | 
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.033 | 0.71x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.052 | 0.45x | 
Configuration: seq_len=256, dtype=bfloat16, no residual
Input Size: 4.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.031 | 0.97x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.030 | baseline | 
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.064 | 0.47x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.050 | 0.60x | 
Configuration: seq_len=256, dtype=bfloat16, no residual
Input Size: 4.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.031 | 0.97x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.030 | baseline | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.049 | 0.61x | 
Configuration: seq_len=512, dtype=bfloat16, no residual
Input Size: 8.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.044 | 0.98x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.043 | baseline | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.297 | 0.15x | 
Configuration: seq_len=1024, dtype=bfloat16, no residual
Input Size: 16.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.071 | 1.00x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.077 | 0.93x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.109 | 0.66x | 
Configuration: seq_len=2048, dtype=bfloat16, no residual
Input Size: 32.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.135 | 1.00x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.143 | 0.94x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.205 | 0.66x | 
Default allreduce
Configuration: seq_len=32, dtype=bfloat16, no residual
Input Size: 0.50 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.029 | 1.00x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.030 | 0.99x | 
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.012 | 2.44x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.087 | 0.34x | 
Configuration: seq_len=64, dtype=bfloat16, no residual
Input Size: 1.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.030 | 1.00x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.030 | 1.00x | 
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.019 | 1.63x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.056 | 0.54x | 
Configuration: seq_len=128, dtype=bfloat16, no residual
Input Size: 2.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.032 | 1.00x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.032 | 1.00x | 
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.033 | 0.97x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.052 | 0.62x | 
Configuration: seq_len=256, dtype=bfloat16, no residual
Input Size: 4.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.051 | 0.98x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.050 | baseline | 
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.064 | 0.77x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.050 | 1.00x | 
Configuration: seq_len=512, dtype=bfloat16, no residual
Input Size: 8.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.079 | 1.00x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.081 | 0.97x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.068 | 1.17x | 
Configuration: seq_len=1024, dtype=bfloat16, no residual
Input Size: 16.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.119 | 1.00x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.125 | 0.95x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.109 | 1.09x | 
Configuration: seq_len=2048, dtype=bfloat16, no residual
Input Size: 32.00 MB
| Operation | Time (ms) | Speedup | 
|---|---|---|
| Standard Allreduce Rmsnorm | 0.195 | 1.00x | 
| Standard Allreduce Rmsnorm Native Compiled | 0.211 | 0.93x | 
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.204 | 0.96x | 
@ilmarkov Is VLLM_ALLREDUCE_USE_SYMM_MEM=1 something that normal vLLM users would set by default? If it's good for performance, why can't we enable it by default? Does it require a special environment or special builds? cc @ProExpertProg
@nvjullin Could you check if @ilmarkov 's measurements above match our understanding? Also, could you try if VLLM_ALLREDUCE_USE_SYMM_MEM=1 works in our case? Thanks!
Yes, it can be enabled by default. There is a PR for it. It works on Hopper and Blackwell.
Got it! We will try both your PRs and run some experiments on our side.
@ilmarkov Just to clarify: the PyTorch SYMM_MEM implementation does not support AR+Norm fusion, right? So only the AR part uses SYMM_MEM while Norm part is based on native PyT?
Yes, symm mem is only used for the allreduce part; the norm and quantization parts are in native PyTorch.
cc @nvjullin @elvischenv for vis
Force-pushed from e808818 to 61ebc95.
This pull request has merge conflicts that must be resolved before it can be merged.
Hi @ilmarkov, is there any progress and ETA for this change? Thanks!
Hi @nvpohanh. @ProExpertProg is working on general custom op matching in #24604, so we will apply the allreduce-related pattern matching after his PR lands. I'm marking the current PR as a draft for now.
This pull request has merge conflicts that must be resolved before it can be merged.
💡 Codex Review
Here are some automated review suggestions for this pull request.
        
          
vllm/compilation/fusion.py (outdated)
        
    def empty_bf16(*args, **kwargs):
    -    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
    +    return torch.empty(*args, **kwargs, dtype=torch.float16, device="cuda")
  Restore bfloat16 in pattern placeholders
The helper empty_bf16 now creates tensors with torch.float16 instead of torch.bfloat16. This helper is used throughout the fusion passes (e.g. attention and activation fusion) to trace the FX patterns that should match bfloat16 graphs. Tracing the pattern in float16 means the captured graph contains dtype-specific ops (such as implicit casts) that no longer match the bfloat16 graphs emitted by models, so bfloat16 models will stop triggering these fusion passes. The helper should keep returning torch.bfloat16 to ensure the traced pattern matches bfloat16 execution.
Can we also add a test for the default setting of the config param?
    )
    backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
    backend.check_after_ops(model.ops_in_model_after())
    del all_reduce_fusion_pass
Unnecessary change?
    import vllm.envs as envs
    -from vllm.config import VllmConfig
    +from vllm.config import VllmConfig, set_current_vllm_config
Is this used?
    self.max_token_num = max_token_num
    self.fuse_rms_quant = fuse_rms_quant
cleanup?
        
          
vllm/config/compilation.py (outdated)
        
    fi_allreduce_fusion_max_size_mb: dict[int,
        float] = field(default_factory=dict)
Suggested change:

    fi_allreduce_fusion_max_size_mb: dict[int, float] = (
        field(default_factory=lambda: deepcopy(resolve_obj_by_qualname(
            "vllm.compilation.fusion_all_reduce._FI_ALLREDUCE_MAX_INPUT_SIZES")))
    )
Okay, I see below it's more complex than that. What about:
Suggested change:

    fi_allreduce_fusion_max_size_mb: dict[int, float] = (
        field(default_factory=PassConfig.default_fi_allreduce_fusion_max_size_mb)
    )
And then below we can define:
    @staticmethod
    def default_fi_allreduce_fusion_max_size_mb():
        if not current_platform.is_cuda():
            return None
        from vllm.compilation.fusion_all_reduce import FI_ALLREDUCE_FUSION_MAX_SIZE_MB
        
        return deepcopy(FI_ALLREDUCE_FUSION_MAX_SIZE_MB)
    4: 32 * MiB,  # 32MB
    8: 1 * MiB,  # 1MB
    },
    }, where key is the device capability"""
Let's set the default dict to FI_ALLREDUCE_FUSION_MAX_SIZE_MB and then in __post_init__ we can do:
    self.fi_allreduce_fusion_max_size_mb = {**FI_ALLREDUCE_FUSION_MAX_SIZE_MB, **self.fi_allreduce_fusion_max_size_mb}
cc @hmellor would this work? Or should we just generate this docstring from _FI_ALLREDUCE_MAX_INPUT_SIZES?
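A minimal, runnable sketch of that merging approach; the class name and the FI_ALLREDUCE_FUSION_MAX_SIZE_MB values here are placeholders, not the real vLLM defaults:

    from copy import deepcopy
    from dataclasses import dataclass, field

    # Placeholder defaults keyed by tensor-parallel world size (illustrative values only).
    FI_ALLREDUCE_FUSION_MAX_SIZE_MB = {2: 64.0, 4: 32.0, 8: 1.0}


    @dataclass
    class PassConfigSketch:
        # User-supplied overrides; any key not given falls back to the defaults.
        fi_allreduce_fusion_max_size_mb: dict[int, float] = field(default_factory=dict)

        def __post_init__(self):
            # Merge: defaults first, user overrides win per key.
            self.fi_allreduce_fusion_max_size_mb = {
                **deepcopy(FI_ALLREDUCE_FUSION_MAX_SIZE_MB),
                **self.fi_allreduce_fusion_max_size_mb,
            }


    # Overriding only the TP8 threshold keeps the other defaults intact.
    cfg = PassConfigSketch(fi_allreduce_fusion_max_size_mb={8: 16.0})
    assert cfg.fi_allreduce_fusion_max_size_mb == {2: 64.0, 4: 32.0, 8: 16.0}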
As far as I know, docstrings cannot be generated like that
I am also realizing that FI_ALLREDUCE_FUSION_MAX_SIZE_MB has different default for different compute capabilities - so we can't set it as the default for the config.
        
          
vllm/config/compilation.py (outdated)
        
    device_capability = current_platform.get_device_capability(
    ).as_version_str()
    fi_allreduce_fusion_max_size_mb = \
        self.fi_allreduce_fusion_max_size_mb.get(device_capability, {})
I thought the dict was already platform specific?
    assert not isinstance(fused_output, tuple)
    else:
    -    shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
    +    fused_output = torch.ops.vllm.moe_forward(
Is there a reason we're changing moe_forward_shared to moe_forward?
It's in the branch where self.shared_experts is None
    states = self.maybe_all_reduce_tensor_model_parallel(states)
    return states

    if self.shared_experts is not None:
I guess my question is: why invert the logic? The diff seems harder to parse because of it (is this because it got inverted in main)?
If yes, could you restore it so it's easier to read?
We use the same order of logic as in the forward_impl custom op, from which we moved the reduction.
    )
    return fused_output[..., :og_hidden_states]
    return (
        reduce_output(shared_output[..., :og_hidden_states], do_combine=False),
Where does this slice come from?
Apparently, moe_forward can return a larger tensor than expected, probably because of padding.
I think this is where the padding is added
vllm/model_executor/layers/fused_moe/layer.py, lines 2119 to 2131 (at 6c728f7):

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        og_hidden_states = hidden_states.shape[-1]
        if self.hidden_size != og_hidden_states:
            hidden_states = F.pad(
                hidden_states,
                (0, self.hidden_size - og_hidden_states),
                mode="constant",
                value=0.0,
            )
Force-pushed from 7088940 to 9516d2b.
    @staticmethod
    def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]:
        from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB
The docs build is failing because this import now happens when running --help, and vllm.compilation.collective_fusion pulls in a number of heavy imports.
    and (self.tp_size > 1 or self.ep_size > 1)
    ):
    states = self.maybe_all_reduce_tensor_model_parallel(states)
    return states
Maybe we should move the naive dispatch call out to this level also.
Also, the original callsites for naive dispatch/combine are inside a sequence parallel context. I'm not sure if that is going to cause problems.
    )


    def standard_allreduce_rmsnorm(
Can we unify the below methods into a single class using set_current_vllm_config and RMSNorm/QuantFP8 instances to reduce duplicated code?
I think it would be best to do a model class and parametrize it on residual & quant (none, fp8, fp4) as well as whether each custom op is enabled
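As a rough illustration of that suggestion, here is a sketch of a parametrized benchmark model. It uses plain torch.nn.functional.rms_norm and a dtype cast as stand-ins for the vLLM RMSNorm/QuantFP8 layers, and the class and argument names are illustrative only:

    import torch
    import torch.distributed as dist
    from torch import nn


    class AllReduceNormModel(nn.Module):
        """Allreduce -> optional residual add -> RMSNorm -> optional FP8 quant."""

        def __init__(self, hidden_size: int, use_residual: bool, quant_mode: str = "none"):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.use_residual = use_residual
            self.quant_mode = quant_mode  # "none" or "fp8"

        def forward(self, x: torch.Tensor, residual: torch.Tensor | None = None):
            if dist.is_initialized():
                dist.all_reduce(x)  # stand-in for the tensor-model-parallel all-reduce
            if self.use_residual and residual is not None:
                x = x + residual
            out = torch.nn.functional.rms_norm(x, (x.shape[-1],), self.weight, eps=1e-6)
            if self.quant_mode == "fp8":
                # Stand-in for QuantFP8; the real benchmark would use the vLLM layer.
                out = out.to(torch.float8_e4m3fn)
            return out

Each benchmark variant would then differ only in how the instance is constructed (and whether it is wrapped in torch.compile), rather than in a separate hand-written function.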
    @torch.compile
    def standard_allreduce_rmsnorm_native_compiled(
        input_tensor: torch.Tensor,
        residual: torch.Tensor | None,
        rmsnorm_layer: RMSNorm,
        norm_out: torch.Tensor | None = None,
    ):
        """Compiled version of standard allreduce + rmsnorm."""
        return standard_allreduce_rmsnorm_native(
            input_tensor, residual, rmsnorm_layer, norm_out
        )
This is way overkill, you can just do:
    standard_allreduce_rmsnorm_native_compiled = torch.compile(standard_allreduce_rmsnorm_native)
Also, we should mark the first dimension as dynamic to make sure we're properly simulating vllm codegen. The QuantFP8 benchmark should do this already if you need an example.
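A sketch of what that could look like; standard_allreduce_rmsnorm_native and rmsnorm_layer are the benchmark objects from the diff above, and the tensor shape is illustrative:

    import torch

    # Compile the eager reference directly instead of wrapping it in a decorated shim.
    standard_allreduce_rmsnorm_native_compiled = torch.compile(standard_allreduce_rmsnorm_native)

    # Mark the token dimension as dynamic before the first call so torch.compile does
    # not specialize on one sequence length, which better matches vLLM's codegen.
    input_tensor = torch.randn(256, 8192, dtype=torch.bfloat16, device="cuda")
    torch._dynamo.mark_dynamic(input_tensor, 0)
    out = standard_allreduce_rmsnorm_native_compiled(input_tensor, None, rmsnorm_layer)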
    # FlashInfer Fused AllReduce + RMSNorm Oneshot
    if flashinfer_comm is not None and allreduce_params is not None:
        try:
            if not disable_oneshot:
Can we loop over use_oneshot here?
    @@ -0,0 +1,1281 @@
    # SPDX-License-Identifier: Apache-2.0
This file is really long and borderline unreadable; can we compact it a bit with some more code reuse? Some suggestions below.
| """ | ||
|  | ||
| for result in all_results: | 
Why not use pandas and to_markdown instead of rolling your own?
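For example (column names are illustrative, and DataFrame.to_markdown needs the tabulate package installed):

    import pandas as pd

    # Each benchmark result becomes one row; pandas handles alignment and formatting.
    rows = [
        {"Operation": "Standard Allreduce Rmsnorm", "Time (ms)": 0.030, "Speedup": "1.00x"},
        {"Operation": "Flashinfer Fused Allreduce Rmsnorm Oneshot", "Time (ms)": 0.018, "Speedup": "1.62x"},
    ]
    print(pd.DataFrame(rows).to_markdown(index=False, floatfmt=".3f"))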
        description="Benchmark fused collective operations"
    )
    parser.add_argument(
        "--seq-lens",
Is this really seq len or is it num_tokens?
    if args.no_quant:
        quant_mode = "none"
    elif args.quant_fp8:
        quant_mode = "fp8_only"
    elif args.quant_fp4:
        quant_mode = "fp4_only"
    else:  # args.quant_all or default
        quant_mode = "all"
Why not make this a comma-separated list with none, fp8, fp4 as options?
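Something along these lines, as a sketch (the flag name and the validation helper are illustrative):

    import argparse


    def parse_quant_modes(value: str) -> list[str]:
        # "--quant-modes none,fp8" -> ["none", "fp8"], validating each entry.
        modes = [m.strip() for m in value.split(",") if m.strip()]
        unknown = set(modes) - {"none", "fp8", "fp4"}
        if unknown:
            raise argparse.ArgumentTypeError(f"unknown quant mode(s): {sorted(unknown)}")
        return modes


    parser = argparse.ArgumentParser(description="Benchmark fused collective operations")
    parser.add_argument(
        "--quant-modes",
        type=parse_quant_modes,
        default=["none"],
        help="Comma-separated list of quantization modes: none, fp8, fp4",
    )
    args = parser.parse_args(["--quant-modes", "none,fp8"])
    assert args.quant_modes == ["none", "fp8"]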
    if use_flashinfer:
        max_tensor_size = max_token_num * hidden_size * element_size

        if current_tensor_size <= max_tensor_size:
Why not just compare num_tokens here? Can always compute current_tensor_size below for the use_oneshot use
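In other words, something like this sketch (variable names follow the diff above; oneshot_max_size is an assumed threshold parameter):

    def should_use_flashinfer_fusion(
        num_tokens: int,
        max_token_num: int,
        hidden_size: int,
        element_size: int,
        oneshot_max_size: int,
    ) -> tuple[bool, bool]:
        # Gate on the token count first; compute the byte size only when it is
        # actually needed to choose between the oneshot and twoshot kernels.
        if num_tokens > max_token_num:
            return False, False
        current_tensor_size = num_tokens * hidden_size * element_size
        use_oneshot = current_tensor_size <= oneshot_max_size
        return True, use_oneshot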
| mode="constant", | ||
| value=0.0, | ||
| ) | ||
| do_naive_dispatch_combine: bool = ( | 
cc @bnellnm @varun-sundar-rabindranath have you been able to take a look at this?
First part of splitting #22086.
Purpose
Adds tuning of the thresholds for FlashInfer allreduce fusion:
- Adds a benchmark for allreduce fusion to determine input-size thresholds for FlashInfer allreduce.
- Updates the thresholds for FlashInfer allreduce on Hopper and Blackwell devices, and enables the two-shot algorithm where it performs better.
- Moves the allreduce out of the moe_forward custom op so that fusion patterns can be matched for MoE models.
Test Plan
Added tests for fusion without custom ops.
Based on #24604
Review link: https://github.com/vllm-project/vllm/pull/24248/files/6253d5bd143a1975213462e7d6c4f8d3a2e1fef7..7088940db26bdee8554418d92ea060279ea7f523