refactor allgather/mc2-related fused_experts #2369
Conversation
Code Review
This pull request refactors the Mixture-of-Experts (MoE) token dispatching logic by introducing a base class `MoETokenDispatcher` and several implementations for different strategies. It also adds a new test suite for the `UnquantizedTokenDispatcherWithMC2` class. While the refactoring improves structure, I've identified several critical issues in both the new implementation and the tests that must be addressed. The tests contain incorrect mock paths and will fail due to a `KeyError`. The implementation has critical bugs such as overwriting initialized variables with `None`, using undefined attributes, and incorrect usage of `super()`. These issues impact correctness and maintainability.
```python
self.need_param = {}  # Replace with actual parameters if needed
self.dispatcher = UnquantizedTokenDispatcherWithMC2(need_param=self.need_param)
```
The `need_param` dictionary is initialized as empty. However, the `UnquantizedTokenDispatcherWithMC2` class (via its parent `MoETokenDispatcher`) expects `top_k` and `num_experts` keys to be present in this dictionary during initialization. This will lead to a `KeyError` when running the tests. Please provide the necessary parameters.
```diff
-self.need_param = {}  # Replace with actual parameters if needed
-self.dispatcher = UnquantizedTokenDispatcherWithMC2(need_param=self.need_param)
+self.need_param = {"top_k": 2, "num_experts": 8}  # Example values
+self.dispatcher = UnquantizedTokenDispatcherWithMC2(need_param=self.need_param)
```
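For context, the `KeyError` arises because the base-class initializer reads these keys from the dictionary. A rough sketch of the shape of `MoETokenDispatcher.__init__` implied by this review (a guess for illustration, not the actual code from the PR):

```python
class MoETokenDispatcher:
    def __init__(self, need_param):
        # Direct indexing is what raises KeyError when need_param is empty.
        self.top_k = need_param["top_k"]
        self.num_experts = need_param["num_experts"]
```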
```python
self.patcher_mc2_group = mock.patch('your_module.get_mc2_group', mock_get_mc2_group)
self.patcher_mc2_group.start()

# Mock ascend config
mock_ascend_config = mock.Mock()
mock_ascend_config.torchair_graph_config.enabled = False
self.patcher_ascend_config = mock.patch('your_module.get_ascend_config', return_value=mock_ascend_config)
self.patcher_ascend_config.start()

# Mock ascend soc version
self.patcher_ascend_version = mock.patch('your_module.get_ascend_soc_version', return_value=AscendSocVersion.A3)
self.patcher_ascend_version.start()

# Mock forward context
mock_forward_context = mock.Mock()
mock_forward_context.mc2_mask = torch.tensor([1, 0, 1])  # Example mask
self.patcher_forward_context = mock.patch('your_module.get_forward_context', return_value=mock_forward_context)
```
The mock paths for `get_mc2_group`, `get_ascend_config`, `get_ascend_soc_version`, and `get_forward_context` are incorrect: they use the placeholder `'your_module'`. The patch target must be the location where the name is looked up; in this case, these functions are imported and used within the `vllm_ascend.ops.moe_dispatcher.token_dispatcher` module. As written, the tests will not patch the real objects and will likely fail. Please correct these paths. For example, `mock.patch('your_module.get_mc2_group', ...)` should be `mock.patch('vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group', ...)`. A similar issue exists in `test_a3_extra_args_handling`.
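To illustrate, a minimal sketch of the corrected patch targets inside `setUp`, reusing the mock objects defined earlier in the method (the `_MOD` constant and the `addCleanup` calls are illustrative additions, not part of the PR):

```python
# Patch the names where they are looked up: the token_dispatcher module.
_MOD = 'vllm_ascend.ops.moe_dispatcher.token_dispatcher'

self.patcher_mc2_group = mock.patch(f'{_MOD}.get_mc2_group', mock_get_mc2_group)
self.patcher_ascend_config = mock.patch(f'{_MOD}.get_ascend_config',
                                        return_value=mock_ascend_config)
self.patcher_ascend_version = mock.patch(f'{_MOD}.get_ascend_soc_version',
                                         return_value=AscendSocVersion.A3)
self.patcher_forward_context = mock.patch(f'{_MOD}.get_forward_context',
                                          return_value=mock_forward_context)
for patcher in (self.patcher_mc2_group, self.patcher_ascend_config,
                self.patcher_ascend_version, self.patcher_forward_context):
    patcher.start()
    self.addCleanup(patcher.stop)  # ensure mocks are removed after each test
```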
```python
self.ep_rank_id = None
self.ep_world_size = None
```
```python
self.bsz = None
device_group = get_ep_group().device_group
self.ep_size = device_group.world_size
self.local_num_experts = self.global_num_experts // self.ep_size
```
The attribute `self.global_num_experts` is used here, but it is not defined in this class or its parent `MoETokenDispatcher`. This will raise an `AttributeError`. It seems you intended to use `self.num_experts`, which is available from the parent class.
```diff
-self.local_num_experts = self.global_num_experts // self.ep_size
+self.local_num_experts = self.num_experts // self.ep_size
```
```python
class UnquantizedTokenDispatcherWithMC2(MoETokenDispatcher):
    def __init__(self, need_param):
        super(MoETokenDispatcher, self).__init__(need_param=need_param)
```
The `super()` call `super(MoETokenDispatcher, self)` is incorrect. It should be `super().__init__(need_param=need_param)` for modern Python, or `super(UnquantizedTokenDispatcherWithMC2, self).__init__(need_param=need_param)`. Passing the base class itself to `super()` starts the method resolution order lookup *after* `MoETokenDispatcher`, so its `__init__` is never called and the call falls through to `object.__init__`. This pattern is repeated in other dispatcher classes in this file (`QuantizedTokenDispatcherWithMC2`, `QuantizedTokenDispatcherWithAllGather`, `UnquantizedTokenDispatcherWithFusedExpertsMoge`).
```diff
-super(MoETokenDispatcher, self).__init__(need_param=need_param)
+super().__init__(need_param=need_param)
```
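As a standalone illustration (hypothetical `Base`/`Broken`/`Fixed` classes, not code from this PR), naming the base class in `super()` starts the MRO lookup after that class, so its `__init__` is skipped:

```python
class Base:
    def __init__(self, need_param):
        self.top_k = need_param["top_k"]

class Broken(Base):
    def __init__(self, need_param):
        # MRO is [Broken, Base, object]; super(Base, self) starts the lookup
        # after Base, so this resolves to object.__init__ and Base.__init__
        # never runs.
        super(Base, self).__init__(need_param=need_param)

class Fixed(Base):
    def __init__(self, need_param):
        super().__init__(need_param=need_param)  # resolves to Base.__init__

Fixed({"top_k": 2})      # ok: Base.__init__ runs and sets self.top_k
# Broken({"top_k": 2})   # would raise TypeError: object.__init__() takes exactly one argument
```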
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
6aff3e9 to 891282f
You can ignore the other UT failures, but please make sure the tests related to this PR pass:
```python
class TestTokenDispatcherWithMC2(unittest.TestCase):

    def setUp(self):
        # Mock get_mc2_group() 返回固定值 (return a fixed value)
```
Remove the Chinese comments in this file.
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
What this PR does / why we need it?
refactor allgather/mc2-related fused_experts
Does this PR introduce any user-facing change?
How was this patch tested?