
Conversation

@vvvdwbvvv
Contributor

Summary

This PR adds support for GLM-4.5 (GLM-4 MoE) models to the Liger Kernel (#951).
The models (https://huggingface.co/zai-org/GLM-4.5) share the same structure as GLM-4.6.
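For reference, a minimal usage sketch, assuming the new entry point is exported from liger_kernel.transformers like the other apply_liger_kernel_to_* patch functions (the model id is the one linked above):

import torch
from transformers import AutoModelForCausalLM

# assumption: exported alongside the other patch functions
from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe

# Monkey-patch the HF GLM-4 MoE modeling code with Liger kernels
# (RMSNorm, SwiGLU, fused linear cross entropy) before loading the model.
apply_liger_kernel_to_glm4_moe()

model = AutoModelForCausalLM.from_pretrained(
    "zai-org/GLM-4.5",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)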

Testing Done

For the fp32 convergence test, the model size can easily lead to OOM. I initially ran the tests on an RTX 4090, but only fp32 hit OOM, so I moved to an L40S to finish all the tests.

  • Hardware Type: NVIDIA L40S (RTX 4090 for initial runs)
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

skip_logits = self.training and (labels is not None or shift_labels is not None)

if skip_logits:
    loss = LigerForCausalLMLoss(
Contributor


Kindly have a look at the other model examples and adapt to the new API that returns the metric.

@vvvdwbvvv
Contributor Author

vvvdwbvvv commented Nov 25, 2025

Fixed in 5af9d16

Comment on lines 2126 to 2140
if swiglu:
    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
    _patch_rms_norm_module(decoder_layer.input_layernorm)
    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
# patch MOE layers
if isinstance(decoder_layer.mlp, Glm4MoeMoE):
    experts = decoder_layer.mlp.experts
    if experts is not None:
        for expert in experts:
            _patch_swiglu_module(expert, LigerSwiGLUMLP)

    shared_experts = decoder_layer.mlp.shared_experts
    if shared_experts is not None:
        _patch_swiglu_module(shared_experts, LigerSwiGLUMLP)
Collaborator


Suggested change
-if swiglu:
-    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
-if rms_norm:
-    _patch_rms_norm_module(decoder_layer.input_layernorm)
-    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
-# patch MOE layers
-if isinstance(decoder_layer.mlp, Glm4MoeMoE):
-    experts = decoder_layer.mlp.experts
-    if experts is not None:
-        for expert in experts:
-            _patch_swiglu_module(expert, LigerSwiGLUMLP)
-    shared_experts = decoder_layer.mlp.shared_experts
-    if shared_experts is not None:
-        _patch_swiglu_module(shared_experts, LigerSwiGLUMLP)
+if swiglu:
+    if isinstance(decoder_layer.mlp, Glm4MoeMoE):
+        experts = decoder_layer.mlp.experts
+        if experts is not None:
+            for expert in experts:
+                _patch_swiglu_module(expert, LigerSwiGLUMLP)
+        shared_experts = decoder_layer.mlp.shared_experts
+        if shared_experts is not None:
+            _patch_swiglu_module(shared_experts, LigerSwiGLUMLP)
+    elif isinstance(decoder_layer.mlp, Glm4MoeMLP):
+        _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+if rms_norm:
+    _patch_rms_norm_module(decoder_layer.input_layernorm)
+    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

You need to check whether this layer is an MLP or a MoE layer before patching, and the MoE patching should be under the if swiglu: scope.

assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(glm4_moe_lce_forward)
assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward)
for decoder_layer in dummy_model_instance.model.layers:
    assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
Collaborator


Same as above, mlp.forward shouldn't be patched if this is a MoE layer. Need to check isinstance() first.
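
A rough sketch of what the revert-side check could look like (assumes the test module's existing imports, and that the mini model mixes Glm4MoeMLP and Glm4MoeMoE layers):

for decoder_layer in dummy_model_instance.model.layers:
    if isinstance(decoder_layer.mlp, Glm4MoeMoE):
        # MoE layer: only the experts / shared experts carry SwiGLU MLPs
        for expert in decoder_layer.mlp.experts:
            assert inspect.getsource(expert.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
        if decoder_layer.mlp.shared_experts is not None:
            assert inspect.getsource(decoder_layer.mlp.shared_experts.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
    elif isinstance(decoder_layer.mlp, Glm4MoeMLP):
        assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward)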

assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(glm4_moe_lce_forward)
assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward)
for decoder_layer in dummy_model_instance.base_model.layers:
    if decoder_layer.mlp is not None:
Collaborator


Check isinstance(decoder_layer.mlp, Glm4MoeMLP) instead of just is not None.
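
i.e. something along these lines (a sketch; assumes the test module's existing imports):

for decoder_layer in dummy_model_instance.base_model.layers:
    if isinstance(decoder_layer.mlp, Glm4MoeMLP):
        # only dense-MLP layers are patched directly with the Liger SwiGLU module
        assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward)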

Comment on lines 2644 to 2649
assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource(
    LigerRMSNormForGlm4.forward
)
assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(
    LigerRMSNormForGlm4.forward
)
Collaborator


Wrong scope, RMSNorm should be checked in both the MLP layers and the MoE layers.
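
In other words, something like this, so every decoder layer is covered regardless of its mlp type (a sketch, reusing the identifiers from the snippet above):

for decoder_layer in dummy_model_instance.model.layers:
    # layernorms exist on both dense-MLP and MoE decoder layers,
    # so these assertions should not sit inside the mlp-type branch
    assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(
        LigerRMSNormForGlm4.forward
    )
    assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource(
        LigerRMSNormForGlm4.forward
    )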

@vvvdwbvvv
Contributor Author

@Tcc0403 Fixed in 876d075

Comment on lines 2217 to 2225
if isinstance(decoder_layer.mlp, Glm4MoeMoE):
    experts = decoder_layer.mlp.experts
    if experts is not None:
        for expert in experts:
            _patch_swiglu_module(expert, LigerSwiGLUMLP)

    shared_experts = decoder_layer.mlp.shared_experts
    if shared_experts is not None:
        _patch_swiglu_module(shared_experts, LigerSwiGLUMLP)
Collaborator


It's not under the if swiglu: scope.

Contributor Author


Thank you, I'll fix it

@vvvdwbvvv
Contributor Author

@Tcc0403 Fixed in cf9f212

Collaborator

@Tcc0403 left a comment


Patching should work, but we need to tune down the MoE-related numbers to speed up the convergence tests.

Comment on lines 1146 to 1169
mini_model_config=Glm4MoeConfig(
    bos_token_id=1,  # None
    eos_token_id=2,  # 151329, 151336, 151338
    pad_token_id=2,  # 151329
    partial_rotary_factor=0.5,
    cross_attention_layers=None,
    dropout=0,
    hidden_act="silu",
    hidden_size=1024,  # 6144
    initializer_range=0.02,
    intermediate_size=2048,  # 14336
    max_position_embeddings=4096,  # 32768
    num_attention_heads=8,  # 48
    num_hidden_layers=4,  # 61
    num_key_value_heads=2,
    rms_norm_eps=1e-5,
    rope_scaling=None,
    rope_theta=500_000,
    tie_word_embeddings=False,
    use_cache=True,
    vocab_size=32000,  # 151552
    attention_bias=True,
    attn_implementation="sdpa",  # default value, pytorch native attention
),
Collaborator


Set smaller expert-related numbers as well:

"moe_intermediate_size": 1408,
"num_experts_per_tok": 2,
"n_shared_experts": 1,
"n_routed_experts": 8,
"routed_scaling_factor": 1.0,
"n_group": 1,
"topk_group": 1,
"first_k_dense_replace": 1,

liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe,
liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe,
model_class=Glm4MoeForCausalLM,
mini_model_config=Glm4MoeConfig(
Collaborator


ditto

liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe,
liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe,
model_class=Glm4MoeForCausalLM,
mini_model_config=Glm4MoeConfig(
Collaborator


ditto

liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe,
liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe,
model_class=Glm4MoeForCausalLM,
mini_model_config=Glm4MoeConfig(
Collaborator


ditto

@vvvdwbvvv
Copy link
Contributor Author

@Tcc0403 Fixed at 375903c, thank you.
