
Commit 454e3d2

vvvdwbvvv, lancerts, and shimizust authored
Add GLM4.5V support (#863)
## Summary

This PR adds support for GLM-4.5V (the `glm4v_moe` vision-language MoE model) to the Liger Kernel (#855).

Model card: https://huggingface.co/zai-org/GLM-4.5

The model has been merged into Transformers in huggingface/transformers#39805.

## Testing Done

`python3 -m pytest test/convergence/bf16/test_mini_models.py -k 'glm4v_moe' -rF` currently fails with `AssertionError: [Loss]Number of mismatched elements: 14`:

<details>
<summary>Test result</summary>

```
AssertionError: [Loss]Number of mismatched elements: 14
Mismatch at index (0, 5): tensor1[(0, 5)] = 8.733983993530273, tensor2[(0, 5)] = 8.52511215209961
Mismatch at index (0, 8): tensor1[(0, 8)] = 7.2776618003845215, tensor2[(0, 8)] = 7.524500846862793
Mismatch at index (0, 9): tensor1[(0, 9)] = 6.917590618133545, tensor2[(0, 9)] = 7.175967216491699
Mismatch at index (0, 13): tensor1[(0, 13)] = 5.685216426849365, tensor2[(0, 13)] = 5.427236557006836
Mismatch at index (0, 14): tensor1[(0, 14)] = 5.337466239929199, tensor2[(0, 14)] = 5.049449443817139
... and 9 more mismatched elements.
```

</details>

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <tangshao28@gmail.com>
Co-authored-by: Steven Shimizu <shimizust@gmail.com>
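For context, a minimal usage sketch (not part of the diff) of the new `apply_liger_kernel_to_glm4v_moe` entry point added by this change. Calling it before the model is loaded patches at the class level; the model id is taken from the docstring example in the diff below:

```python
# Sketch only: assumes a transformers version that includes glm4v_moe (see linked PR).
from transformers import Glm4vMoeForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe

# Class-level patch: swaps modeling_glm4v_moe.Glm4vRMSNorm for Liger's RMSNorm and
# replaces Glm4vMoeForConditionalGeneration.forward with the fused linear cross entropy forward.
apply_liger_kernel_to_glm4v_moe()

model = Glm4vMoeForConditionalGeneration.from_pretrained(
    "zai-org/GLM-4.5V",
    torch_dtype="auto",
    device_map="auto",
)
# Training forward passes that pass `labels` now compute the loss without materializing logits.
```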
1 parent 7812567 commit 454e3d2

10 files changed: +809, -0 lines changed

src/liger_kernel/transformers/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -36,6 +36,7 @@
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4  # noqa: F401
@@ -95,6 +96,7 @@ def __getattr__(name: str):
     "apply_liger_kernel_to_gemma3_text",
     "apply_liger_kernel_to_glm4",
     "apply_liger_kernel_to_glm4v",
+    "apply_liger_kernel_to_glm4v_moe",
     "apply_liger_kernel_to_granite",
     "apply_liger_kernel_to_llama",
     "apply_liger_kernel_to_llava",
@@ -159,6 +161,7 @@ def __getattr__(name: str):
     "apply_liger_kernel_to_gemma3_text",
     "apply_liger_kernel_to_glm4",
     "apply_liger_kernel_to_glm4v",
+    "apply_liger_kernel_to_glm4v_moe",
     "apply_liger_kernel_to_granite",
     "apply_liger_kernel_to_llama",
     "apply_liger_kernel_to_llava",
```
src/liger_kernel/transformers/model/glm4v_moe.py

Lines changed: 152 additions & 0 deletions (new file)
````python
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeCausalLMOutputWithPast
from transformers.utils.deprecation import deprecate_kwarg

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, Glm4vMoeCausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            The rope index difference between sequence length and multimodal rope.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

    Example:

    ```python
    >>> from transformers import AutoProcessor, Glm4vMoeForConditionalGeneration
    >>> import torch

    >>> MODEL_PATH = "zai-org/GLM-4.5V"
    >>> messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "url": "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png"
                    },
                    {
                        "type": "text",
                        "text": "describe this image"
                    }
                ],
            }
        ]
    >>> processor = AutoProcessor.from_pretrained(MODEL_PATH)
    >>> model = Glm4vMoeForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=MODEL_PATH,
            torch_dtype="auto",
            device_map="auto",
        )
    >>> inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        ).to(model.device)
    >>> inputs.pop("token_type_ids", None)
    >>> generated_ids = model.generate(**inputs, max_new_tokens=8192)
    >>> output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
    ```
    """

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        loss = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    return Glm4vMoeCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=outputs.rope_deltas,
    )
````
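Not part of the diff: a brief sketch of how this forward behaves once patched, assuming `model` is a `Glm4vMoeForConditionalGeneration` instance patched by `apply_liger_kernel_to_glm4v_moe` and `batch` is a processor-produced dict that also includes `labels`:

```python
# Sketch only; `model` and `batch` are assumed to exist (see the docstring example above).
import torch

model.train()
out = model(**batch)       # skip_logits defaults to True here (training mode + labels given),
out.loss.backward()        # so out.logits is None and the loss comes from LigerForCausalLMLoss

model.eval()
with torch.no_grad():
    out = model(**batch, skip_logits=False)  # explicitly materialize logits
    print(out.logits.shape, out.loss)        # loss computed via the standard loss_function
```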

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 102 additions & 0 deletions
```diff
@@ -1928,6 +1928,107 @@ def apply_liger_kernel_to_glm4v(
             _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
 
 
+def apply_liger_kernel_to_glm4v_moe(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4v_moe import modeling_glm4v_moe
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeForConditionalGeneration
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeModel
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextModel
+    from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeVisionModel
+
+    from liger_kernel.transformers.model.glm4v_moe import lce_forward as glm4v_moe_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
+    if rms_norm:
+        modeling_glm4v_moe.Glm4vRMSNorm = LigerRMSNormForGlm4
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4v_moe_lce_forward, model)
+        else:
+            modeling_glm4v_moe.Glm4vMoeForConditionalGeneration.forward = glm4v_moe_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
+            # Note: language_model and visual properties can be accessed through the conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
+            text_model: Glm4vMoeTextModel = model.language_model
+            vision_model: Glm4vMoeVisionModel = model.visual
+            Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+        elif isinstance(model, Glm4vMoeTextModel):
+            text_model: Glm4vMoeTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported glm4v_moe model type. `model` must be `Glm4vMoeForConditionalGeneration`, `Glm4vMoeVisionModel` or `Glm4vMoeTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            _patch_rms_norm_module(vision_model.post_conv_layernorm)
+            _patch_rms_norm_module(vision_model.post_layernorm)
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+                if swiglu:
+                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    decoder_layer.mlp = _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                if isinstance(Glm4vMoeTextMoE, type) and isinstance(decoder_layer.mlp, Glm4vMoeTextMoE):
+                    experts = getattr(decoder_layer.mlp, "experts", None)
+                    if experts is not None:
+                        for expert in experts:
+                            _patch_swiglu_module(expert, LigerSwiGLUMLP)
+                    if decoder_layer.mlp.shared_experts is not None:
+                        _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP)
+            for decoder_layer in text_model.layers:
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
@@ -1936,6 +2037,7 @@ def apply_liger_kernel_to_glm4v(
     "gemma3": apply_liger_kernel_to_gemma3,
     "glm4": apply_liger_kernel_to_glm4,
     "glm4v": apply_liger_kernel_to_glm4v,
+    "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
     "llama": apply_liger_kernel_to_llama,
     "llama4_text": apply_liger_kernel_to_llama4,
     "llama4": apply_liger_kernel_to_llama4,
```
