Monkeypatch for Phi3 #76

Merged: 13 commits, Aug 27, 2024
setup.py (1 addition, 0 deletions)
@@ -40,6 +40,7 @@
"isort>=5.13.2",
"pre-commit>=3.7.1",
"torch-tb-profiler>=0.4.1",
"pytest>=8.3.2",
]
},
)
src/liger_kernel/transformers/__init__.py (1 addition, 0 deletions)
@@ -3,5 +3,6 @@
apply_liger_kernel_to_llama,
apply_liger_kernel_to_mistral,
apply_liger_kernel_to_mixtral,
apply_liger_kernel_to_phi3,
apply_liger_kernel_to_qwen2,
)
src/liger_kernel/transformers/monkey_patch.py (32 additions, 1 deletion)
@@ -4,7 +4,11 @@
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP
from liger_kernel.transformers.swiglu import (
LigerBlockSparseTop2MLP,
LigerPhi3SwiGLUMLP,
LigerSwiGLUMLP,
)


def apply_liger_kernel_to_llama(
@@ -167,3 +171,30 @@ def apply_liger_kernel_to_qwen2(
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
if swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP


def apply_liger_kernel_to_phi3(
rope: bool = True,
cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.

Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
"""
from transformers.models.phi3 import modeling_phi3

if rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
if rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
if swiglu:
modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
if cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
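
Usage note (not part of the diff): apply_liger_kernel_to_phi3 swaps classes at the transformers module level, so it has to run before the model is constructed; already-instantiated modules are unaffected. A minimal sketch, with the checkpoint name used only for illustration:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_phi3

# Patch HF Phi3 with Liger kernels (RoPE, RMSNorm, SwiGLU, cross entropy)
# before the model is built, so the patched classes are picked up.
apply_liger_kernel_to_phi3(rope=True, rms_norm=True, swiglu=True, cross_entropy=True)

model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
```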
src/liger_kernel/transformers/swiglu.py (24 additions, 0 deletions)
@@ -38,3 +38,27 @@ def __init__(self, config):
def forward(self, x):

return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))


class LigerPhi3SwiGLUMLP(nn.Module):
Contributor (author) commented:
So Llama has its own MLP implementation here, but it's named very generally. I went for a model-specific name here, but I'm open to suggestions.

Collaborator replied:
Yeah, this is necessary.

"""
Patch Phi3MLP to use LigerSiLUMulFunction
https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241
"""

def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = nn.Linear(
self.hidden_size, 2 * self.intermediate_size, bias=False
)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
if config.hidden_act not in ["silu", "swish"]:
raise ValueError(f"Activation function {config.hidden_act} not supported.")

def forward(self, x):
up_states = self.gate_up_proj(x)
gate, up_states = up_states.chunk(2, dim=-1)
return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
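
For reference, the computation LigerPhi3SwiGLUMLP fuses can be written out in eager PyTorch: Phi3 packs the gate and up projections into a single gate_up_proj, chunks the output in half, and computes silu(gate) * up before the down projection. A small unfused sketch (illustrative only, no Triton kernel involved):

```python
import torch
import torch.nn.functional as F


def phi3_mlp_reference(x, gate_up_weight, down_weight):
    """Eager-mode reference for Phi3's SwiGLU MLP (what LigerSiLUMulFunction fuses)."""
    # gate_up_weight: (2 * intermediate_size, hidden_size), nn.Linear weight layout
    up_states = F.linear(x, gate_up_weight)           # (..., 2 * intermediate_size)
    gate, up = up_states.chunk(2, dim=-1)             # split into gate and up halves
    return F.linear(F.silu(gate) * up, down_weight)   # silu(gate) * up, then down_proj
```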
src/liger_kernel/transformers/trainer_integration.py (2 additions, 0 deletions)
@@ -5,6 +5,7 @@
apply_liger_kernel_to_llama,
apply_liger_kernel_to_mistral,
apply_liger_kernel_to_mixtral,
apply_liger_kernel_to_phi3,
)

logger = logging.getLogger(__name__)
@@ -15,6 +16,7 @@
"llama": apply_liger_kernel_to_llama,
"mistral": apply_liger_kernel_to_mistral,
"mixtral": apply_liger_kernel_to_mixtral,
"phi3": apply_liger_kernel_to_phi3,
}
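
With "phi3" registered in this mapping, _apply_liger_kernel can dispatch on a config's model_type string and forward patching kwargs (as exercised by test_apply_liger_kernel_passes_kwargs below). The helper is private, so the following is only a sketch of the dispatch, not a public API recommendation; the checkpoint name is illustrative:

```python
from transformers import AutoConfig

from liger_kernel.transformers.trainer_integration import _apply_liger_kernel

config = AutoConfig.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
# config.model_type == "phi3", so the mapping resolves to apply_liger_kernel_to_phi3
_apply_liger_kernel(config.model_type, rms_norm=True, swiglu=True)
```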


test/convergence/test_mini_models.py (28 additions, 0 deletions)
@@ -14,12 +14,14 @@
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from transformers.models.mistral import MistralConfig, MistralForCausalLM
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM

from liger_kernel.transformers import (
apply_liger_kernel_to_llama,
apply_liger_kernel_to_mistral,
apply_liger_kernel_to_mixtral,
apply_liger_kernel_to_phi3,
apply_liger_kernel_to_qwen2,
)

@@ -144,6 +146,30 @@
attn_implementation="sdpa", # default value, pytorch native attention
),
),
"mini_phi3": MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_phi3,
model_class=Phi3ForCausalLM,
mini_model_config=Phi3Config(
attention_dropout=0.0,
bos_token_id=1,
eos_token_id=2, # 32000
hidden_act="silu",
hidden_size=896, # 3072
initializer_range=0.02,
intermediate_size=4864, # 8192
max_position_embeddings=4096,
num_attention_heads=8, # 32
num_hidden_layers=4, # 32
num_key_value_heads=None, # defaults to num_attention_heads
rms_norm_eps=1e-5,
rope_theta=10000.0,
sliding_window=None,
tie_word_embeddings=False,
use_cache=True,
vocab_size=32064,
attn_implementation="eager",
),
),
}
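
Aside from the convergence harness, a hedged sketch of how the patch can be verified on a mini Phi3 model built from a scaled-down config like the one above (attribute paths follow the transformers v4.41 Phi3 implementation):

```python
from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_phi3

apply_liger_kernel_to_phi3()

mini_config = Phi3Config(
    hidden_size=896,
    intermediate_size=4864,
    num_hidden_layers=4,
    num_attention_heads=8,
    vocab_size=32064,
)
model = Phi3ForCausalLM(mini_config)

# With the patch applied before construction, each decoder layer's MLP
# should now be LigerPhi3SwiGLUMLP rather than the stock Phi3MLP.
print(type(model.model.layers[0].mlp).__name__)
```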


@@ -211,6 +237,8 @@ def run_mini_model(
("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
],
)
def test_mini_model(
test/transformers/test_swiglu.py (84 additions, 2 deletions)
@@ -2,14 +2,21 @@
import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from transformers.models.phi3.configuration_phi3 import Phi3Config
from transformers.models.phi3.modeling_phi3 import Phi3MLP

from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP, LigerSwiGLUMLP

LLAMA_CONFIG = LlamaConfig(
hidden_size=4096,
intermediate_size=11008,
hidden_act="silu",
)
PHI3_CONFIG = Phi3Config(
hidden_size=4096,
intermediate_size=11008,
hidden_act="silu",
)
SLEEP_SECONDS = 0.1


@@ -33,7 +40,9 @@
(torch.bfloat16, 1e4, 1e-2),
],
)
def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol):
def test_correctness_llamamlp(
bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol
):

_input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype)

@@ -94,3 +103,76 @@ def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol,
)

assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) is True


@pytest.mark.parametrize(
"bsz, seq_len, hidden_size, intermediate_size",
[
(2, 2048, 4096, 11008),
(2, 2048, 2048, 4096),
# weird shapes
(9, 41, 341, 4231),
(6, 42, 256, 2048),
],
)
@pytest.mark.parametrize(
"dtype, atol, rtol",
[
# atol is for small values: they have more difference, so set atol higher
# rtol is for larger values: they are very close, so set rtol lower
(torch.float32, 1e-0, 1e-5),
# TODO: we should find a better way to tune this. 1e4 is too large apparently
(torch.bfloat16, 1e4, 1e-2),
],
)
def test_correctness_phi3mlp(
bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol
):

_input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype)

x1 = _input.clone().requires_grad_(True)
x2 = _input.clone().requires_grad_(True)

# initialize weights
GU = torch.randn(hidden_size, intermediate_size * 2, device="cuda", dtype=dtype)
D = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype)

phi3_mlp = Phi3MLP(config=PHI3_CONFIG).to("cuda").to(dtype)
phi3_mlp.gate_up_proj.weight.data = GU.T
phi3_mlp.down_proj.weight.data = D.T

liger_mlp = LigerPhi3SwiGLUMLP(config=PHI3_CONFIG).to("cuda").to(dtype)
liger_mlp.gate_up_proj.weight.data = GU.T
liger_mlp.down_proj.weight.data = D.T

y1 = phi3_mlp(x1)
y2 = liger_mlp(x2)

assert torch.allclose(y1, y2, atol=atol, rtol=rtol) is True

dy = torch.randn_like(y1)

y1.backward(dy.clone(), retain_graph=True)
y2.backward(dy.clone(), retain_graph=True)

assert (
torch.allclose(
phi3_mlp.gate_up_proj.weight.grad,
liger_mlp.gate_up_proj.weight.grad,
atol=atol,
rtol=rtol,
)
is True
)
assert (
torch.allclose(
phi3_mlp.down_proj.weight.grad,
liger_mlp.down_proj.weight.grad,
atol=atol,
rtol=rtol,
)
is True
)

assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) is True
test/transformers/test_trainer_integration.py (11 additions, 1 deletion)
@@ -20,15 +20,25 @@ def test_apply_liger_kernel_only_supported_model_type_called():
mock_gemma = Mock()
mock_llama = Mock()
mock_mistral = Mock()
mock_mixtral = Mock()
mock_phi3 = Mock()

with patch.dict(
MODEL_TYPE_TO_APPLY_LIGER_FN,
{"gemma": mock_gemma, "llama": mock_llama, "mistral": mock_mistral},
{
"gemma": mock_gemma,
"llama": mock_llama,
"mistral": mock_mistral,
"mixtral": mock_mixtral,
"phi3": mock_phi3,
},
):
_apply_liger_kernel("llama")
mock_llama.assert_called_once()
mock_gemma.assert_not_called()
mock_mistral.assert_not_called()
mock_mixtral.assert_not_called()
mock_phi3.assert_not_called()


def test_apply_liger_kernel_passes_kwargs():
test/transformers/test_transformers_monkey_patch.py (1 addition, 0 deletions)
@@ -8,6 +8,7 @@ def test_import_from_root():
apply_liger_kernel_to_llama,
apply_liger_kernel_to_mistral,
apply_liger_kernel_to_mixtral,
apply_liger_kernel_to_phi3,
apply_liger_kernel_to_qwen2,
)
except Exception: