Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ loss.backward()
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |


## Low-level APIs
Expand Down
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
Expand Down Expand Up @@ -79,6 +80,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_gemma2",
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_llama",
"apply_liger_kernel_to_llava",
Expand Down Expand Up @@ -129,6 +131,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_gemma2",
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_llama",
"apply_liger_kernel_to_llava",
Expand Down
123 changes: 123 additions & 0 deletions src/liger_kernel/transformers/model/glm4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.glm4.modeling_glm4 import _CONFIG_FOR_DOC
from transformers.models.glm4.modeling_glm4 import GLM4_INPUTS_DOCSTRING
from transformers.utils import add_start_docstrings_to_model_forward
from transformers.utils import replace_return_docstrings
from transformers.utils.deprecation import deprecate_kwarg

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GLM4_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).

Returns:

Example:

```python
>>> from transformers import AutoTokenizer, Glm4ForCausalLM

>>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
>>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")

>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)

hidden_states = outputs[0]

shift_labels = loss_kwargs.pop("shift_labels", None)
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None or shift_labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**loss_kwargs,
)

else: # if in inference mode materialize logits
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**loss_kwargs,
)

return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
65 changes: 65 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
Expand Down Expand Up @@ -1319,12 +1320,76 @@ def apply_liger_kernel_to_olmo2(
_patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)


def apply_liger_kernel_to_glm4(
rope: bool = False,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.

Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)

from transformers.models.glm4 import modeling_glm4
from transformers.models.glm4.modeling_glm4 import Glm4Model

if rope:
raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
if rms_norm:
modeling_glm4.Glm4RMSNorm = partial(LigerRMSNorm, in_place=False)
if swiglu:
modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
if cross_entropy:
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward

if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules

# get the base model from the model instance
base_model: Glm4Model = getattr(model, model.base_model_prefix, model)

if rms_norm:
_patch_rms_norm_module(base_model.norm, in_place=False)

for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
_patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
_patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)


# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
"gemma": apply_liger_kernel_to_gemma,
"gemma2": apply_liger_kernel_to_gemma2,
"gemma3_text": apply_liger_kernel_to_gemma3_text,
"gemma3": apply_liger_kernel_to_gemma3,
"glm4": apply_liger_kernel_to_glm4,
"llama": apply_liger_kernel_to_llama,
"llava": apply_liger_kernel_to_llava,
"granite": apply_liger_kernel_to_granite,
Expand Down
63 changes: 63 additions & 0 deletions test/convergence/bf16/test_mini_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from liger_kernel.transformers import apply_liger_kernel_to_gemma
from liger_kernel.transformers import apply_liger_kernel_to_gemma2
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text
from liger_kernel.transformers import apply_liger_kernel_to_glm4
from liger_kernel.transformers import apply_liger_kernel_to_granite
from liger_kernel.transformers import apply_liger_kernel_to_llama
from liger_kernel.transformers import apply_liger_kernel_to_llava
Expand All @@ -38,6 +39,7 @@
from test.utils import revert_liger_kernel_to_gemma
from test.utils import revert_liger_kernel_to_gemma2
from test.utils import revert_liger_kernel_to_gemma3_text
from test.utils import revert_liger_kernel_to_glm4
from test.utils import revert_liger_kernel_to_granite
from test.utils import revert_liger_kernel_to_llama
from test.utils import revert_liger_kernel_to_llava
Expand Down Expand Up @@ -106,6 +108,14 @@
except ImportError:
OLMO2_AVAILABLE = False

try:
# Glm4 is only available in transformers>=4.51.3
from transformers.models.glm4.configuration_glm4 import Glm4Config
from transformers.models.glm4.modeling_glm4 import Glm4ForCausalLM

GLM4_AVAILABLE = True
except ImportError:
GLM4_AVAILABLE = False

try:
from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
Expand Down Expand Up @@ -644,6 +654,37 @@
),
)

if GLM4_AVAILABLE:
MINI_MODEL_SETUPS["mini_glm4"] = MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_glm4,
liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4,
model_class=Glm4ForCausalLM,
mini_model_config=Glm4Config(
bos_token_id=1, # None
eos_token_id=2, # 151329, 151336, 151338
pad_token_id=2, # 151329
partial_rotary_factor=0.5,
cross_attention_layers=None,
dropout=0,
hidden_act="silu",
hidden_size=1024, # 6144
initializer_range=0.02,
intermediate_size=2048, # 14336
max_position_embeddings=4096, # 32768
num_attention_heads=8, # 48
num_hidden_layers=4, # 61
num_key_value_heads=2,
rms_norm_eps=1e-5,
rope_scaling=None,
rope_theta=500_000,
tie_word_embeddings=False,
use_cache=True,
vocab_size=32000, # 151552
attention_bias=True,
attn_implementation="sdpa", # default value, pytorch native attention
),
)


def create_model(model_name="mini_llama3"):
"""
Expand Down Expand Up @@ -679,6 +720,9 @@ def run_mini_model(
"rms_norm": True,
}

if "glm4" in model_name:
kwargs["rope"] = False

model_supports_layer_norm = "qwen2_vl" in model_name
if model_supports_layer_norm:
kwargs["layer_norm"] = True
Expand Down Expand Up @@ -890,6 +934,25 @@ def run_mini_model(
),
],
),
pytest.param(
"mini_glm4",
32,
1e-4,
torch.bfloat16,
1e-3,
1e-2,
1e-1,
1e-2,
1e-2,
1e-2,
marks=[
pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
pytest.mark.skipif(
not GLM4_AVAILABLE,
reason="Glm4 not available in this version of transformers",
),
],
),
# TODO: mixtral is flaky so disable the test for now
# pytest.param(
# "mini_mixtral",
Expand Down
Loading
Loading