[MULTI-GPU] Optimize reduce_scatter (except all-to-all) using custom triton kernels #8300
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
if position != -1:
    row_offsets = position * hidden_size
    col_offsets = row_offsets + offs_n

    for tp_idx in tl.static_range(TP):
        values = tl.load(
            input_ptrs[tl.constexpr(ep_idx * TP + tp_idx)] + col_offsets,
            mask=io_mask,
            other=0,
        ).to(intermediate_dtype)
        output += values
[P0] Use Triton mask instead of Python branch
The new Triton accumulation kernel attempts to gate loads with a Python-level check, if position != -1:. position is a runtime tl.int32 tensor, so evaluating it in a Python if raises TypeError: bool value of Tensor is ambiguous during JIT compilation. As a result the Triton kernel never compiles and the CommKernelType.TRITON path fails immediately. Use a tensor mask (tl.where / tl.load with masks) instead of Python control flow so the kernel can compile.
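For illustration, a minimal sketch (not the PR's actual code) of how that check could be folded into the load mask, reusing the names from the snippet above and treating position as a runtime value; load_mask is a hypothetical name:

# Hypothetical sketch: express the position check as part of the load mask
# instead of a Python-level branch.
valid = position != -1                  # device-side boolean, not a Python bool
row_offsets = position * hidden_size
col_offsets = row_offsets + offs_n
load_mask = io_mask & valid             # masked lanes are never dereferenced,
                                        # so the negative offset for -1 is harmless
for tp_idx in tl.static_range(TP):
    values = tl.load(
        input_ptrs[tl.constexpr(ep_idx * TP + tp_idx)] + col_offsets,
        mask=load_mask,
        other=0,                        # masked lanes contribute 0 to the sum
    ).to(intermediate_dtype)
    output += values

The same broadcast works whether position is a per-program scalar or a per-token vector.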
@codex position is a scalar
Summary
- Replaced the Python-side conditional in the EP accumulation Triton kernel with tensor-based masking so the kernel compiles while skipping invalid positions during loads. python/triton_kernels/bench/distributed.py (L178-L193)
Testing
⚠️ Tests not run (not requested)
💡 Codex Review
Here are some automated review suggestions for this pull request.
_accumulate_ep_triton_kernel[(n_tokens, )](
    positions,
    output_tensor,
    tuple(output_list),
[P0] Triton kernel called with unsupported tuple argument
The new _reduce_ep_triton path launches _accumulate_ep_triton_kernel and passes tuple(output_list) as a single kernel argument. Triton kernel parameters must be individual tensors or constexpr literals; a Python tuple is not a valid runtime value, so the launch will raise a TypeError before any computation runs. As written, the TRITON communication path introduced in this commit cannot execute at all. Consider passing the tensors as separate arguments or materializing a device-side pointer array and indexing that inside the kernel.
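As an illustration of the pointer-array alternative, a hedged sketch (the name shard_ptrs, the bfloat16 element type, and the torch import are assumptions, not the PR's code):

import torch

# Hypothetical sketch: collect the shard base addresses into a single int64
# tensor on the device and pass that tensor instead of a Python tuple.
shard_ptrs = torch.tensor(
    [t.data_ptr() for t in output_list],   # one base address per EP*TP shard
    dtype=torch.int64,
    device=output_tensor.device,
)

# Inside the kernel, a typed pointer can then be recovered from the table:
#   base = tl.load(shard_ptrs_ptr + shard_idx).to(tl.pointer_type(tl.bfloat16))
#   values = tl.load(base + col_offsets, mask=io_mask, other=0)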
offs_m = tl.program_id(0)
token_mask = offs_m < n_tokens
offs_n = tl.arange(0, BLOCK_SIZE_N)
feature_mask = offs_n < hidden_size
io_mask = token_mask[:] & feature_mask
output = tl.zeros((BLOCK_SIZE_N, ), dtype=intermediate_dtype)
[P0] Scalar mask indexed as vector in Triton kernel
Inside _accumulate_ep_triton_kernel the mask is computed as io_mask = token_mask[:] & feature_mask. token_mask is a scalar (offs_m < n_tokens), so subscripting it with [:] is invalid in Triton and causes compilation to fail before the kernel can launch. This prevents the custom reduce-scatter kernel from building. A scalar mask can be broadcast directly with token_mask & feature_mask, or by constructing a vector mask of the appropriate shape.
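A minimal sketch of the broadcast form, using the same variable names as the snippet above (not the PR's final code):

offs_m = tl.program_id(0)
token_mask = offs_m < n_tokens             # scalar boolean
offs_n = tl.arange(0, BLOCK_SIZE_N)
feature_mask = offs_n < hidden_size        # vector of BLOCK_SIZE_N booleans
io_mask = token_mask & feature_mask        # the scalar broadcasts over the vector
output = tl.zeros((BLOCK_SIZE_N, ), dtype=intermediate_dtype)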
At context length 32768:
Before: [benchmark trace]
After: [benchmark trace]
Reduce is no longer the bottleneck; communication is.