diff --git a/Makefile b/Makefile index 6b00d4cd..6c1c9190 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,6 @@ all: test test-convergence checkstyle # Command to run pytest for correctness tests test: python -m pytest --disable-warnings test/ --ignore=test/convergence - # Command to run flake8 (code style check), isort (import ordering), and black (code formatting) # Subsequent commands still run if the previous fails, but return failure at the end diff --git a/README.md b/README.md index 2367c392..24ca197b 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ -[![Downloads](https://static.pepy.tech/badge/liger-kernel)](https://pepy.tech/project/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel.svg)](https://badge.fury.io/py/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel-nightly.svg)](https://badge.fury.io/py/liger-kernel-nightly) -[![](https://dcbadge.vercel.app/api/server/cudamode?style=flat)](https://discord.gg/CX2YmNmn) +[![Downloads](https://static.pepy.tech/badge/liger-kernel)](https://pepy.tech/project/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel.svg)](https://badge.fury.io/py/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel-nightly.svg)](https://badge.fury.io/py/liger-kernel-nightly) +[![](https://dcbadge.vercel.app/api/server/cudamode?style=flat)](https://discord.gg/CX2YmNmn) @@ -33,8 +33,8 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and | ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) | > **Note:** -> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s. -> - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K. +> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s. +> - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K. ## Examples @@ -72,7 +72,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and - `torch >= 2.1.2` - `triton >= 2.3.0` -- `transformers >= 4.40.1` +- `transformers >= 4.41.0` > **Note:** > Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton). @@ -80,7 +80,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and To install the stable version: ```bash -$ pip install liger-kernel +$ pip install liger-kernel ``` To install the nightly version: @@ -109,7 +109,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_llama model = transformers.AutoModelForCausalLM.from_pretrained("") # Adding this line automatically monkey-patches the model with the optimized Liger kernels -apply_liger_kernel_to_llama() +apply_liger_kernel_to_llama() ``` ### 2. Compose Your Own Model @@ -161,6 +161,8 @@ loss.backward() | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss | | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss | | Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss | + ### Kernels @@ -175,11 +177,11 @@ loss.backward() | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`| - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction. -- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction. -- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by +- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction. +- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$ , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. -- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by +- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$ , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used. - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.). @@ -188,7 +190,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$ -> **Note:** +> **Note:** > Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder. ## Note on ML Compiler @@ -202,7 +204,7 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compil | Torch Compile | 3780 | 66.4 | | Torch Compile + Liger Kernel | 3702 | 31.0 | -> **Note:** +> **Note:** > 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s. > 2. Tested on torch `2.5.0.dev20240731+cu118` diff --git a/setup.py b/setup.py index d678220c..ca101d2b 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ install_requires=[ "torch>=2.1.2", "triton>=2.3.0", - "transformers>=4.40.1", + "transformers>=4.41.0", ], extras_require={ "dev": [ diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 4c6148d4..836823ac 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -3,5 +3,6 @@ apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, + apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, ) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 8ef96f7b..f976981e 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -7,7 +7,11 @@ from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb -from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP +from liger_kernel.transformers.swiglu import ( + LigerBlockSparseTop2MLP, + LigerPhi3SwiGLUMLP, + LigerSwiGLUMLP, +) def apply_liger_kernel_to_llama( @@ -181,3 +185,30 @@ def apply_liger_kernel_to_qwen2( modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward if swiglu: modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP + + +def apply_liger_kernel_to_phi3( + rope: bool = True, + cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Phi3 models. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True. + """ + from transformers.models.phi3 import modeling_phi3 + + if rope: + modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma + if rms_norm: + modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama + if swiglu: + modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP + if cross_entropy: + modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss diff --git a/src/liger_kernel/transformers/swiglu.py b/src/liger_kernel/transformers/swiglu.py index ebf1f0c0..42f4df10 100644 --- a/src/liger_kernel/transformers/swiglu.py +++ b/src/liger_kernel/transformers/swiglu.py @@ -38,3 +38,27 @@ def __init__(self, config): def forward(self, x): return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x))) + + +class LigerPhi3SwiGLUMLP(nn.Module): + """ + Patch Phi3MLP to use LigerSiLUMulFunction + https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241 + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = nn.Linear( + self.hidden_size, 2 * self.intermediate_size, bias=False + ) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if config.hidden_act not in ["silu", "swish"]: + raise ValueError(f"Activation function {config.hidden_act} not supported.") + + def forward(self, x): + up_states = self.gate_up_proj(x) + gate, up_states = up_states.chunk(2, dim=-1) + return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states)) diff --git a/src/liger_kernel/transformers/trainer_integration.py b/src/liger_kernel/transformers/trainer_integration.py index 4caf0317..b943404d 100644 --- a/src/liger_kernel/transformers/trainer_integration.py +++ b/src/liger_kernel/transformers/trainer_integration.py @@ -5,6 +5,7 @@ apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, + apply_liger_kernel_to_phi3, ) logger = logging.getLogger(__name__) @@ -15,6 +16,7 @@ "llama": apply_liger_kernel_to_llama, "mistral": apply_liger_kernel_to_mistral, "mixtral": apply_liger_kernel_to_mixtral, + "phi3": apply_liger_kernel_to_phi3, } diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 62cf58a9..4ad791fd 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -15,6 +15,7 @@ from transformers.models.llama import LlamaConfig, LlamaForCausalLM from transformers.models.mistral import MistralConfig, MistralForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM +from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM from liger_kernel.transformers import ( @@ -22,6 +23,7 @@ apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, + apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, ) @@ -176,6 +178,30 @@ attn_implementation="sdpa", # default value, pytorch native attention ), ), + "mini_phi3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_phi3, + model_class=Phi3ForCausalLM, + mini_model_config=Phi3Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, # 32000 + hidden_act="silu", + hidden_size=896, # 3072 + initializer_range=0.02, + intermediate_size=4864, # 8192 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=None, # defaults to num_attention_heads + rms_norm_eps=1e-5, + rope_theta=10000.0, + sliding_window=None, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32064, + attn_implementation="eager", + ), + ), } @@ -253,6 +279,8 @@ def run_mini_model( ("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), ("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), + ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), ], ) def test_mini_model( diff --git a/test/transformers/test_swiglu.py b/test/transformers/test_swiglu.py index 14132c2a..0b8ef3d4 100644 --- a/test/transformers/test_swiglu.py +++ b/test/transformers/test_swiglu.py @@ -2,14 +2,21 @@ import torch from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaMLP +from transformers.models.phi3.configuration_phi3 import Phi3Config +from transformers.models.phi3.modeling_phi3 import Phi3MLP -from liger_kernel.transformers.swiglu import LigerSwiGLUMLP +from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP, LigerSwiGLUMLP LLAMA_CONFIG = LlamaConfig( hidden_size=4096, intermediate_size=11008, hidden_act="silu", ) +PHI3_CONFIG = Phi3Config( + hidden_size=4096, + intermediate_size=11008, + hidden_act="silu", +) SLEEP_SECONDS = 0.1 @@ -33,7 +40,9 @@ (torch.bfloat16, 1e4, 1e-2), ], ) -def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol): +def test_correctness_llamamlp( + bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol +): _input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype) @@ -58,39 +67,97 @@ def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, y1 = llama_mlp(x1) y2 = liger_mlp(x2) - assert torch.allclose(y1, y2, atol=atol, rtol=rtol) is True + assert torch.allclose(y1, y2, atol=atol, rtol=rtol) dy = torch.randn_like(y1) y1.backward(dy.clone(), retain_graph=True) y2.backward(dy.clone(), retain_graph=True) - assert ( - torch.allclose( - llama_mlp.gate_proj.weight.grad, - liger_mlp.gate_proj.weight.grad, - atol=atol, - rtol=rtol, - ) - is True + assert torch.allclose( + llama_mlp.gate_proj.weight.grad, + liger_mlp.gate_proj.weight.grad, + atol=atol, + rtol=rtol, ) - assert ( - torch.allclose( - llama_mlp.up_proj.weight.grad, - liger_mlp.up_proj.weight.grad, - atol=atol, - rtol=rtol, - ) - is True + assert torch.allclose( + llama_mlp.up_proj.weight.grad, + liger_mlp.up_proj.weight.grad, + atol=atol, + rtol=rtol, + ) + assert torch.allclose( + llama_mlp.down_proj.weight.grad, + liger_mlp.down_proj.weight.grad, + atol=atol, + rtol=rtol, + ) + + assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "bsz, seq_len, hidden_size, intermediate_size", + [ + (2, 2048, 4096, 11008), + (2, 2048, 2048, 4096), + # weird shapes + (9, 41, 341, 4231), + (6, 42, 256, 2048), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + # atol is for small values: they have more difference, so set atol higher + # rtol is for larger values: they are very close, so set rtol lower + (torch.float32, 1e-0, 1e-5), + # TODO: we should find a better way to tune this. 1e4 is too large apparently + (torch.bfloat16, 1e4, 1e-2), + ], +) +def test_correctness_phi3mlp( + bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol +): + + _input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype) + + x1 = _input.clone().requires_grad_(True) + x2 = _input.clone().requires_grad_(True) + + # initialize weights + GU = torch.randn(hidden_size, intermediate_size * 2, device="cuda", dtype=dtype) + D = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) + + phi3_mlp = Phi3MLP(config=PHI3_CONFIG).to("cuda").to(dtype) + phi3_mlp.gate_up_proj.weight.data = GU.T + phi3_mlp.down_proj.weight.data = D.T + + liger_mlp = LigerPhi3SwiGLUMLP(config=PHI3_CONFIG).to("cuda").to(dtype) + liger_mlp.gate_up_proj.weight.data = GU.T + liger_mlp.down_proj.weight.data = D.T + + y1 = phi3_mlp(x1) + y2 = liger_mlp(x2) + + assert torch.allclose(y1, y2, atol=atol, rtol=rtol) + + dy = torch.randn_like(y1) + + y1.backward(dy.clone(), retain_graph=True) + y2.backward(dy.clone(), retain_graph=True) + + assert torch.allclose( + phi3_mlp.gate_up_proj.weight.grad, + liger_mlp.gate_up_proj.weight.grad, + atol=atol, + rtol=rtol, ) - assert ( - torch.allclose( - llama_mlp.down_proj.weight.grad, - liger_mlp.down_proj.weight.grad, - atol=atol, - rtol=rtol, - ) - is True + assert torch.allclose( + phi3_mlp.down_proj.weight.grad, + liger_mlp.down_proj.weight.grad, + atol=atol, + rtol=rtol, ) - assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) is True + assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) diff --git a/test/transformers/test_trainer_integration.py b/test/transformers/test_trainer_integration.py index e200b006..397cc0db 100644 --- a/test/transformers/test_trainer_integration.py +++ b/test/transformers/test_trainer_integration.py @@ -20,6 +20,8 @@ def test_apply_liger_kernel_only_supported_model_type_called(): mock_gemma = Mock() mock_llama = Mock() mock_mistral = Mock() + mock_mixtral = Mock() + mock_phi3 = Mock() with patch.dict( MODEL_TYPE_TO_APPLY_LIGER_FN, @@ -27,12 +29,16 @@ def test_apply_liger_kernel_only_supported_model_type_called(): "gemma": mock_gemma, "llama": mock_llama, "mistral": mock_mistral, + "mixtral": mock_mixtral, + "phi3": mock_phi3, }, ): _apply_liger_kernel("llama") mock_llama.assert_called_once() mock_gemma.assert_not_called() mock_mistral.assert_not_called() + mock_mixtral.assert_not_called() + mock_phi3.assert_not_called() def test_apply_liger_kernel_passes_kwargs(): diff --git a/test/transformers/test_transformers_monkey_patch.py b/test/transformers/test_transformers_monkey_patch.py index 2af747da..9443d56c 100644 --- a/test/transformers/test_transformers_monkey_patch.py +++ b/test/transformers/test_transformers_monkey_patch.py @@ -8,6 +8,7 @@ def test_import_from_root(): apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, + apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, ) except Exception: