Implement gather before attn #6378
self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
input_mode=self.layer_scatter_modes.layer_input_mode,
output_mode=self.layer_scatter_modes.attn_mode,
hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
Wondering whether it should be:
- hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
+ hidden_states_input_mode=self.layer_scatter_modes.layer_input_mode,
output_mode=self.layer_scatter_modes.attn_mode,
hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
residual_input_mode=self.layer_scatter_modes.layer_input_mode,
hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
Suggested change:
- hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
+ hidden_states_output_mode=self.layer_scatter_modes.attn_mode,
hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
residual_input_mode=self.layer_scatter_modes.layer_input_mode,
hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
If that is the case, then maybe the residual mode of _communicate_with_all_reduce_and_layer_norm_fn should be changed.
if (
hidden_states_input_mode == ScatterMode.TP_ATTN_FULL
and residual_input_mode == ScatterMode.SCATTERED
and hidden_states_output_mode == ScatterMode.TP_ATTN_FULL
and residual_output_mode == ScatterMode.SCATTERED
):
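Hmm, in this branch the input modes equal the output modes (hidden states: TP_ATTN_FULL to TP_ATTN_FULL, residual: SCATTERED to SCATTERED), so we should do nothing, i.e. return the trivial function.
Maybe the condition is a bit wrong?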
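To illustrate the point about "input == output", here is a minimal sketch of a dispatcher that treats identical input and output modes as a no-op. It is not the PR's implementation: `get_fn_sketch` and `trivial` are hypothetical names, and `ScatterMode` is redeclared with only the two members that appear in this diff.

```python
from enum import Enum, auto


class ScatterMode(Enum):
    # Assumption: only the two members visible in this diff are modeled here.
    SCATTERED = auto()
    TP_ATTN_FULL = auto()


def trivial(hidden_states, residual):
    """No-op: the tensors are already laid out as required."""
    return hidden_states, residual


def get_fn_sketch(
    hidden_states_input_mode,
    residual_input_mode,
    hidden_states_output_mode,
    residual_output_mode,
):
    """Hypothetical dispatcher: pick a communication function from the modes."""
    # If neither tensor changes its scatter mode, no communication is needed.
    if (
        hidden_states_input_mode == hidden_states_output_mode
        and residual_input_mode == residual_output_mode
    ):
        return trivial
    # Real branches (gather, scatter, ...) would be listed here.
    raise NotImplementedError("mode combination not covered by this sketch")


fn = get_fn_sketch(
    ScatterMode.TP_ATTN_FULL,
    ScatterMode.SCATTERED,
    ScatterMode.TP_ATTN_FULL,
    ScatterMode.SCATTERED,
)
```

With the modes from the branch under review, `fn` is `trivial`, which is why a gather function returned from that branch looks suspicious.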
and residual_output_mode == ScatterMode.TP_ATTN_FULL
):
return CommunicateSimpleFn._scattered_to_tp_attn_full
return CommunicateSimpleFn._gather_hidden_states_and_residual
From the if condition, it looks like we only gather the residual, so I wonder whether the if condition is wrong (or the function name is wrong).
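One way to check whether a condition matches a function name is to list which tensors actually move from SCATTERED to TP_ATTN_FULL under that condition. The helper below is a hypothetical sketch, not code from the PR; `tensors_to_gather` is an invented name, and the example mode combination is an assumption, since only part of the condition is visible in the diff.

```python
from enum import Enum, auto


class ScatterMode(Enum):
    # Assumption: only the two members visible in this diff are modeled here.
    SCATTERED = auto()
    TP_ATTN_FULL = auto()


def tensors_to_gather(
    hidden_states_input_mode,
    residual_input_mode,
    hidden_states_output_mode,
    residual_output_mode,
):
    """Return the names of tensors that go from SCATTERED to TP_ATTN_FULL."""
    gather = (ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL)
    needed = []
    if (hidden_states_input_mode, hidden_states_output_mode) == gather:
        needed.append("hidden_states")
    if (residual_input_mode, residual_output_mode) == gather:
        needed.append("residual")
    return needed


# Assumed condition: hidden states stay TP_ATTN_FULL, residual is gathered.
only_residual = tensors_to_gather(
    ScatterMode.TP_ATTN_FULL,
    ScatterMode.SCATTERED,
    ScatterMode.TP_ATTN_FULL,
    ScatterMode.TP_ATTN_FULL,
)

# A condition that would match a "gather hidden states and residual" name.
both = tensors_to_gather(
    ScatterMode.SCATTERED,
    ScatterMode.SCATTERED,
    ScatterMode.TP_ATTN_FULL,
    ScatterMode.TP_ATTN_FULL,
)
```

Under the assumed condition only the residual changes mode, so either the condition or the name `_gather_hidden_states_and_residual` would need to change for the two to agree.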
LGTM
(forgot to click approve for the review above)
@ch-wan maybe you should run "pre-commit run --all-files" to pass the lint tests
Motivation
Fix an issue that occurs when we use DP for dense FFNs and EP (not DeepEP) for sparse FFNs. Related issue: #6297
Modifications
Checklist