Add GLM4_MOE model support #952
Conversation
skip_logits = self.training and (labels is not None or shift_labels is not None)

if skip_logits:
    loss = LigerForCausalLMLoss(
Kindly have a look at the other model examples and adapt to the new API that returns the metric.
Fixed in 5af9d16
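For context, the gating pattern in question looks roughly like the following in other Liger model forwards. This is a sketch of a fragment of the `forward` body only: the variable names `kept_hidden_states` and `loss_kwargs`, and the exact keyword arguments of `LigerForCausalLMLoss`, are assumptions here, not taken from this PR; the authoritative version (including the new API that also returns the metric) is the one in the existing model files.

```python
# Sketch only: argument names are assumed from other Liger lce_forward implementations.
skip_logits = self.training and (labels is not None or shift_labels is not None)

loss = None
if skip_logits:
    # Fused linear + cross-entropy path: full logits are never materialized.
    loss = LigerForCausalLMLoss(
        hidden_states=kept_hidden_states,   # hypothetical name for the sliced hidden states
        lm_head_weight=self.lm_head.weight,
        labels=labels,
        shift_labels=shift_labels,
        hidden_size=self.config.hidden_size,
        **loss_kwargs,                      # hypothetical name for extra loss kwargs
    )
```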
if swiglu:
    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
    _patch_rms_norm_module(decoder_layer.input_layernorm)
    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
# patch MOE layers
if isinstance(decoder_layer.mlp, Glm4MoeMoE):
    experts = decoder_layer.mlp.experts
    if experts is not None:
        for expert in experts:
            _patch_swiglu_module(expert, LigerSwiGLUMLP)

    shared_experts = decoder_layer.mlp.shared_experts
    if shared_experts is not None:
        _patch_swiglu_module(shared_experts, LigerSwiGLUMLP)
Suggested change:

-if swiglu:
-    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
-if rms_norm:
-    _patch_rms_norm_module(decoder_layer.input_layernorm)
-    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
-# patch MOE layers
-if isinstance(decoder_layer.mlp, Glm4MoeMoE):
-    experts = decoder_layer.mlp.experts
-    if experts is not None:
-        for expert in experts:
-            _patch_swiglu_module(expert, LigerSwiGLUMLP)
-    shared_experts = decoder_layer.mlp.shared_experts
-    if shared_experts is not None:
-        _patch_swiglu_module(shared_experts, LigerSwiGLUMLP)
+if swiglu:
+    if isinstance(decoder_layer.mlp, Glm4MoeMoE):
+        experts = decoder_layer.mlp.experts
+        if experts is not None:
+            for expert in experts:
+                _patch_swiglu_module(expert, LigerSwiGLUMLP)
+        shared_experts = decoder_layer.mlp.shared_experts
+        if shared_experts is not None:
+            _patch_swiglu_module(shared_experts, LigerSwiGLUMLP)
+    elif isinstance(decoder_layer.mlp, Glm4MoeMLP):
+        _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+if rms_norm:
+    _patch_rms_norm_module(decoder_layer.input_layernorm)
+    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
You need to check whether the layer is an MLP or a MoE layer before patching, and the MoE patching should be under the if swiglu: scope.
assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(glm4_moe_lce_forward)
assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward)
for decoder_layer in dummy_model_instance.model.layers:
    assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
Same as above: mlp.forward shouldn't be compared against the patched forward if this is a MoE layer. Need to check isinstance() first.
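For illustration, the unpatched-model check could branch on the layer type along these lines. This is a sketch using the names from the snippet above; the MoE branch (checking experts and shared experts) is an assumption about how the test might be extended.

```python
for decoder_layer in dummy_model_instance.model.layers:
    if isinstance(decoder_layer.mlp, Glm4MoeMLP):
        # dense MLP layers: forward must still be the stock implementation
        assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
    elif isinstance(decoder_layer.mlp, Glm4MoeMoE):
        # MoE layers: check the experts and shared experts instead of mlp itself
        for expert in decoder_layer.mlp.experts:
            assert inspect.getsource(expert.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
        assert inspect.getsource(decoder_layer.mlp.shared_experts.forward) != inspect.getsource(
            LigerSwiGLUMLP.forward
        )
```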
assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(glm4_moe_lce_forward)
assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward)
for decoder_layer in dummy_model_instance.base_model.layers:
    if decoder_layer.mlp is not None:
Check isinstance(decoder_layer.mlp, Glm4MoeMLP) instead of just is not None.
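In code, the guard could look roughly like this (a sketch; the isinstance check is the only change being suggested here):

```python
for decoder_layer in dummy_model_instance.base_model.layers:
    # only dense MLP layers are swapped for LigerSwiGLUMLP, so only they are asserted here
    if isinstance(decoder_layer.mlp, Glm4MoeMLP):
        assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward)
```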
assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource(
    LigerRMSNormForGlm4.forward
)
assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(
    LigerRMSNormForGlm4.forward
)
Wrong scope: the RMS norm should be checked for both MLP layers and MoE layers.
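One possible shape for the loop, with the RMS norm assertions lifted to the layer level so they run for both dense and MoE layers (a sketch built from the assertions quoted above):

```python
for decoder_layer in dummy_model_instance.base_model.layers:
    # RMS norm is patched on every decoder layer, regardless of MLP vs MoE
    assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(
        LigerRMSNormForGlm4.forward
    )
    assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource(
        LigerRMSNormForGlm4.forward
    )
    # SwiGLU patching is only asserted where it applies
    if isinstance(decoder_layer.mlp, Glm4MoeMLP):
        assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward)
```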
Force-pushed from 79c22df to 5cf027e
if isinstance(decoder_layer.mlp, Glm4MoeMoE):
    experts = decoder_layer.mlp.experts
    if experts is not None:
        for expert in experts:
            _patch_swiglu_module(expert, LigerSwiGLUMLP)

    shared_experts = decoder_layer.mlp.shared_experts
    if shared_experts is not None:
        _patch_swiglu_module(shared_experts, LigerSwiGLUMLP)
It's not under the if swiglu: scope.
Thank you, I'll fix it
Tcc0403 left a comment
Patching should work, but we need to tune down the MoE-related numbers to speed up the convergence tests.
mini_model_config=Glm4MoeConfig(
    bos_token_id=1,  # None
    eos_token_id=2,  # 151329, 151336, 151338
    pad_token_id=2,  # 151329
    partial_rotary_factor=0.5,
    cross_attention_layers=None,
    dropout=0,
    hidden_act="silu",
    hidden_size=1024,  # 6144
    initializer_range=0.02,
    intermediate_size=2048,  # 14336
    max_position_embeddings=4096,  # 32768
    num_attention_heads=8,  # 48
    num_hidden_layers=4,  # 61
    num_key_value_heads=2,
    rms_norm_eps=1e-5,
    rope_scaling=None,
    rope_theta=500_000,
    tie_word_embeddings=False,
    use_cache=True,
    vocab_size=32000,  # 151552
    attention_bias=True,
    attn_implementation="sdpa",  # default value, pytorch native attention
),
Set smaller expert-related numbers as well:
"moe_intermediate_size": 1408,
"num_experts_per_tok": 2,
"n_shared_experts": 1,
"n_routed_experts": 8,
"routed_scaling_factor": 1.0,
"n_group": 1,
"topk_group": 1,
"first_k_dense_replace": 1,
liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe,
liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe,
model_class=Glm4MoeForCausalLM,
mini_model_config=Glm4MoeConfig(
ditto
liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe,
liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe,
model_class=Glm4MoeForCausalLM,
mini_model_config=Glm4MoeConfig(
ditto
liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe,
liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe,
model_class=Glm4MoeForCausalLM,
mini_model_config=Glm4MoeConfig(
ditto
Summary
This PR adds support for GLM-4.5 (GLM-4 MoE) models to the Liger Kernel (#951):
https://huggingface.co/zai-org/GLM-4.5, which shares the same structure as GLM-4.6.
Testing Done
For the convergence test on fp32, the model size can easily lead to OOM. Initially I was using a 4090 to run the tests, but only fp32 ran into OOM, so I moved to an L40S to finish all the tests.
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence