Skip to content

Conversation

ch-wan
Copy link
Collaborator

@ch-wan ch-wan commented May 18, 2025

Motivation

Fix one issue when we use DP for dense FFNs and EP (not DeepEP) for sparse FFNs. Related issue: #6297

Modifications

Checklist

@ch-wan ch-wan requested review from BBuf and HaiShaw as code owners May 27, 2025 08:47
self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
input_mode=self.layer_scatter_modes.layer_input_mode,
output_mode=self.layer_scatter_modes.attn_mode,
hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering whether it should be

Suggested change
hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
hidden_states_input_mode=self.layer_scatter_modes.layer_input_mode,

output_mode=self.layer_scatter_modes.attn_mode,
hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
residual_input_mode=self.layer_scatter_modes.layer_input_mode,
hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
hidden_states_output_mode=self.layer_scatter_modes.attn_mode,

hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
residual_input_mode=self.layer_scatter_modes.layer_input_mode,
hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if it is like that, then maybe _communicate_with_all_reduce_and_layer_norm_fn's residual residual mode should be changed

Comment on lines 283 to 288
if (
hidden_states_input_mode == ScatterMode.TP_ATTN_FULL
and residual_input_mode == ScatterMode.SCATTERED
and hidden_states_output_mode == ScatterMode.TP_ATTN_FULL
and residual_output_mode == ScatterMode.SCATTERED
):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, this branch looks like "input === output", then we should do nothing, i.e. trivial

maybe the condition is a bit wrong?

and residual_output_mode == ScatterMode.TP_ATTN_FULL
):
return CommunicateSimpleFn._scattered_to_tp_attn_full
return CommunicateSimpleFn._gather_hidden_states_and_residual
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the if looks like we only gathere residual, so wondering maybe if condition is wrong (or the function name is wrong)

Copy link
Collaborator

@fzyzcjy fzyzcjy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Collaborator

@fzyzcjy fzyzcjy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(forget to click approve for the review above)

@zyksir
Copy link
Collaborator

zyksir commented May 30, 2025

@ch-wan maybe you should run "pre-commit run --all-files" to pass the lint tests

@zhyncs zhyncs merged commit 3c2274f into main Jun 16, 2025
0 of 48 checks passed
@zhyncs zhyncs deleted the cheng/gather_before_attn branch June 16, 2025 04:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants