Fix Gemma RMSNorm #85

Merged
merged 2 commits on Aug 26, 2024
Changes from 1 commit
add gemma to convergence tests
davidgonmar committed Aug 25, 2024
commit 683002aa2f1cab8b4e8c37a9313401d206e7bcdd
4 changes: 2 additions & 2 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -120,13 +120,13 @@ def apply_liger_kernel_to_gemma(
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
     """
-    # TODO(yundai424): add convergence test for gemma
     from transformers.models.gemma import modeling_gemma
 
     if rope:
         modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-        modeling_gemma.GemmaRMSNorm = partial(LigerRMSNorm, offset=1.0)
+        # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
+        modeling_gemma.GemmaRMSNorm = partial(LigerRMSNorm, offset=1.0, init_fn="zeros")
     if cross_entropy:
         modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
     if geglu:
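For context, the transformers file linked in the added comment defines `GemmaRMSNorm` with a zero-initialized weight and scales the normalized activations by `(1.0 + weight)`. A minimal sketch of that reference behavior (paraphrased from memory rather than copied verbatim, with a hypothetical class name):

```python
import torch
import torch.nn as nn

# Sketch of the upstream GemmaRMSNorm behavior: the learnable weight starts at
# zeros and the output is scaled by (1.0 + weight), i.e. an identity scale at init.
class ReferenceGemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # zeros, not ones

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.float().pow(2).mean(-1, keepdim=True)
        normed = x.float() * torch.rsqrt(variance + self.eps)
        return (normed * (1.0 + self.weight.float())).type_as(x)
```

Patching with `partial(LigerRMSNorm, offset=1.0)` alone kept the default `torch.ones` initialization, so a freshly created Gemma model started with an effective scale of `offset + 1 = 2`; passing `init_fn="zeros"` restores the intended identity scale at initialization.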
10 changes: 8 additions & 2 deletions src/liger_kernel/transformers/rms_norm.py
@@ -5,9 +5,15 @@


 class LigerRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6, offset=0.0):
+    def __init__(self, hidden_size, eps=1e-6, offset=0.0, init_fn="ones"):
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        assert init_fn in [
+            "ones",
+            "zeros",
+        ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
+        self.weight = nn.Parameter(
+            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
+        )
         self.variance_epsilon = eps
         self.offset = offset

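For illustration, a small usage sketch of the new `init_fn` argument (sizes here are arbitrary; the forward pass dispatches to a Triton kernel, so a CUDA device is assumed):

```python
import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm

# Gemma-style norm: zero-initialized weight plus offset=1.0, so the effective
# scale (offset + weight) is 1.0 right after initialization.
gemma_norm = LigerRMSNorm(hidden_size=64, eps=1e-6, offset=1.0, init_fn="zeros").to("cuda")

# Llama-style norm keeps the previous defaults: ones init, no offset.
llama_norm = LigerRMSNorm(hidden_size=64).to("cuda")

x = torch.randn(2, 8, 64, device="cuda")
y = gemma_norm(x)  # same shape as x

# Any other init_fn value fails fast via the assert above, e.g.:
# LigerRMSNorm(hidden_size=64, init_fn="normal")  -> AssertionError
```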
46 changes: 43 additions & 3 deletions test/convergence/test_mini_models.py
@@ -11,12 +11,14 @@
 import torch
 from datasets import load_from_disk
 from torch.utils.data import DataLoader
+from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
 from transformers.models.llama import LlamaConfig, LlamaForCausalLM
 from transformers.models.mistral import MistralConfig, MistralForCausalLM
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 
 from liger_kernel.transformers import (
+    apply_liger_kernel_to_gemma,
     apply_liger_kernel_to_llama,
     apply_liger_kernel_to_mistral,
     apply_liger_kernel_to_mixtral,
@@ -58,6 +60,34 @@
             attn_implementation="sdpa",  # default value, pytorch native attention
         ),
     ),
+    "mini_gemma": MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_gemma,
+        model_class=GemmaForCausalLM,
+        mini_model_config=GemmaConfig(
+            vocab_size=32000,  # 256000
+            hidden_size=1024,  # 3072
+            intermediate_size=2048,  # 24576
+            num_hidden_layers=4,  # 28
+            num_attention_heads=4,  # 16
+            num_key_value_heads=4,  # 16
+            head_dim=256,
+            hidden_act="gelu_pytorch_tanh",
+            hidden_activation=None,
+            max_position_embeddings=8192,
+            initializer_range=0.02,
+            rms_norm_eps=1e-06,
+            use_cache=True,
+            pad_token_id=0,
+            # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
+            # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
+            bos_token_id=1,  # 128000
+            eos_token_id=2,  # 128001
+            tie_word_embeddings=True,
+            rope_theta=10000.0,
+            attention_bias=False,
+            attention_dropout=0.0,
+        ),
+    ),
     "mini_mistral": MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_mistral,
         model_class=MistralForCausalLM,
@@ -172,9 +202,16 @@ def run_mini_model(
     set_seed(42)
 
     if with_liger is True:
-        MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(
-            rope=True, rms_norm=True, cross_entropy=True, swiglu=True
-        )
+        kwargs = {
+            "rope": True,
+            "rms_norm": True,
+            "cross_entropy": True,
+        }
+        if model_name == "mini_gemma":
+            kwargs["geglu"] = True
+        else:
+            kwargs["swiglu"] = True
+        MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
 
     model = create_model(model_name).to(dtype).to("cuda")
     train_dataset = load_from_disk(DEFAULT_DATASET_PATH)
@@ -201,6 +238,9 @@ def run_mini_model(
 @pytest.mark.parametrize(
     "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
     [
+        ("mini_gemma", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
+        # mini_gemma has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
+        ("mini_gemma", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5),
         ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 1e-4, 1e-5, 2e-3, 1e-5),
         ("mini_llama3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
         # TODO: torch 2.5.0 nightly breaks mixtral test, but torch 2.3.0 works fine
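For reference, a hedged sketch of what the `mini_gemma` branch of the kwargs logic above amounts to when called directly (note `geglu=True` instead of `swiglu`, since Gemma uses a GeGLU MLP; the tiny config values are arbitrary):

```python
from transformers.models.gemma import GemmaConfig, GemmaForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_gemma

# Patch the HF Gemma module in place *before* instantiating the model,
# which is what run_mini_model does for "mini_gemma" via **kwargs.
apply_liger_kernel_to_gemma(rope=True, rms_norm=True, cross_entropy=True, geglu=True)

config = GemmaConfig(vocab_size=32000, hidden_size=1024, num_hidden_layers=4)
model = GemmaForCausalLM(config)
# The model's norm layers are now LigerRMSNorm(offset=1.0, init_fn="zeros") instances.
```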