Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate Liger (Linkedin GPU Efficient Runtime) Kernel to Trainer #32860

Merged
merged 25 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9773382
add liger integration
JasonZhu1313 Aug 16, 2024
e94e09a
fix syntax
JasonZhu1313 Aug 17, 2024
20c78c8
fix import issue
JasonZhu1313 Aug 17, 2024
f44157f
add trainer.md
JasonZhu1313 Aug 19, 2024
f4e9747
Use _apply_liger_kernel()
shimizust Aug 19, 2024
38e2acd
Fixed log message
shimizust Aug 19, 2024
f27fdce
Update docs/source/en/trainer.md
shimizust Aug 20, 2024
a74ca24
Update docs/source/en/trainer.md
shimizust Aug 20, 2024
d3d29f4
Update src/transformers/training_args.py
shimizust Aug 20, 2024
29b13a9
Update src/transformers/trainer.py
shimizust Aug 20, 2024
f0b2125
Update src/transformers/training_args.py
shimizust Aug 20, 2024
8639629
Update docs/source/en/trainer.md
shimizust Aug 20, 2024
2d7c4ab
Fixed checkstyle and updated readme
shimizust Aug 20, 2024
e51eb93
Added test
shimizust Aug 20, 2024
c286e16
Fixed checkstyle
shimizust Aug 20, 2024
fc05ba6
fix docstring
JasonZhu1313 Aug 20, 2024
b2bae31
rename use_liger to use_liger_kernel
JasonZhu1313 Aug 20, 2024
d0b4be4
Trigger Build
JasonZhu1313 Aug 20, 2024
f2af439
Merge branch 'huggingface:main' into jaszhu/liger-kernel
JasonZhu1313 Aug 20, 2024
59a900b
Added test
shimizust Aug 21, 2024
7a88b06
Merge branch 'huggingface:main' into jaszhu/liger-kernel
JasonZhu1313 Aug 21, 2024
c2756e9
Merge branch 'huggingface:main' into jaszhu/liger-kernel
JasonZhu1313 Aug 22, 2024
62eff43
add fix-copies
JasonZhu1313 Aug 22, 2024
d3ae400
Merge branch 'huggingface:main' into jaszhu/liger-kernel
JasonZhu1313 Aug 22, 2024
eaf602b
Fixed copy inconsistencies
shimizust Aug 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions docs/source/en/trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,41 @@ trainer.train()

Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.

## Liger Kernel

[Liger-Kernel](https://github.com/linkedin/Liger-Kernel) Kernel is a collection of Triton kernels developed by Linkedin designed specifically for LLM training. We have implemented Hugging Face Compatible RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more to come. It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. The kernel works out of the box with flash attention, PyTorch FSDP, and Microsoft DeepSpeed.

<Tip>
Gain +20% throughput and reduce memory usage by 60% on LLaMA 3-8B model training. Achieve longer context lengths and larger batch sizes. It’s also useful if you want to scale up your model to multi-head training or large vocabulary sizes. Unleash multi-head training (medusa) and more. See details and examples in [Liger](https://github.com/linkedin/Liger-Kernel/tree/main/examples)
</Tip>

First make sure to install Liger official repository:
```bash
pip install liger-kernel
```

You should pass `use_liger_kernel=True` to apply liger kernel on your model, for example:

```py
from transformers import TrainingArguments

training_args = TrainingArguments(
output_dir="your-model",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=2,
weight_decay=0.01,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
push_to_hub=True,
use_liger_kernel=True
)
```

The kernel supports the Llama, Gemma, Mistral, and Mixtral model architectures. The most up-to-date list of supported models can be found [here](https://github.com/linkedin/Liger-Kernel). When `use_liger_kernel` is set to `True`, the corresponding layers in the original model will be patched with Liger's efficient implementation, so you don't need to do anything extra other than setting the argument value.

## LOMO optimizer

The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
Expand Down
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
is_keras_nlp_available,
is_levenshtein_available,
is_librosa_available,
is_liger_kernel_available,
is_lomo_available,
is_natten_available,
is_nltk_available,
Expand Down Expand Up @@ -1162,6 +1163,13 @@ def require_librosa(test_case):
return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)


def require_liger_kernel(test_case):
"""
Decorator marking a test that requires liger_kernel
"""
return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case)


def require_essentia(test_case):
"""
Decorator marking a test that requires essentia
Expand Down
19 changes: 19 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
is_grokadamw_available,
is_in_notebook,
is_ipex_available,
is_liger_kernel_available,
is_lomo_available,
is_peft_available,
is_safetensors_available,
Expand Down Expand Up @@ -463,6 +464,24 @@ def __init__(
" to `True` to avoid any unexpected behavior such as device placement mismatching."
)

if self.args.use_liger_kernel:
if is_liger_kernel_available():
from liger_kernel.transformers.trainer_integration import _apply_liger_kernel

model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
if model_type:
# Monkey patch the model with liger kernels. Use the default kernel configurations.
_apply_liger_kernel(model_type=model_type)
else:
logger.warning(
"The model does not have a valid `model_type` specified. No liger kernels will be applied."
)
else:
raise ImportError(
"You have set `use_liger_kernel` to `True` but liger-kernel >= 0.1.0 is not available. "
"Please install it with `pip install liger-kernel`"
)

_is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
model, "_hf_peft_config_loaded", False
)
Expand Down
10 changes: 10 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,11 @@ class TrainingArguments:

eval_use_gather_object (`bool`, *optional*, defaults to `False`):
Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices. This should only be enabled if users are not just returning tensors, and this is actively discouraged by PyTorch.

use_liger_kernel (`bool`, *optional*, defaults to `False`):
Whether enable [Liger](https://github.com/linkedin/Liger-Kernel) Kernel for LLM model training.
It can effectively increase multi-GPU training throughput by ~20% and reduces memory usage by ~60%, works out of the box with
flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama, mistral, mixtral and gemma models.
"""

framework = "pt"
Expand Down Expand Up @@ -1492,6 +1497,11 @@ class TrainingArguments:
},
)

use_liger_kernel: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
)

eval_use_gather_object: Optional[bool] = field(
default=False,
metadata={
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
is_keras_nlp_available,
is_levenshtein_available,
is_librosa_available,
is_liger_kernel_available,
is_lomo_available,
is_mlx_available,
is_natten_available,
Expand Down
8 changes: 8 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")
_hqq_available = _is_package_available("hqq")
_liger_kernel_available = _is_package_available("liger_kernel")


_torch_version = "N/A"
Expand Down Expand Up @@ -1164,6 +1165,13 @@ def is_mlx_available():
return _mlx_available


def is_liger_kernel_available():
if not _liger_kernel_available:
return False

return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.1.0")


# docstyle-ignore
AV_IMPORT_ERROR = """
{0} requires the PyAv library but it was not found in your environment. You can install it with:
Expand Down
20 changes: 20 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
require_galore_torch,
require_grokadamw,
require_intel_extension_for_pytorch,
require_liger_kernel,
require_lomo,
require_optuna,
require_peft,
Expand Down Expand Up @@ -1324,6 +1325,25 @@ def test_get_eval_dataloader_with_persistent_workers(self):
self.assertEqual(first_dataloader, first_dataloader_repeated)
self.assertEqual(second_dataloader, second_dataloader_repeated)

@require_liger_kernel
def test_apply_liger_kernel(self):
# Test that the model code actually gets patched with Liger kernel
from liger_kernel.transformers.rms_norm import LigerRMSNorm

from transformers.models.llama import modeling_llama

config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_model = LlamaForCausalLM(config)

args = TrainingArguments(
"./test",
use_liger_kernel=True,
)
Trainer(tiny_model, args)

# Check that one of the Llama model layers has been correctly patched with Liger kernel
self.assertEqual(modeling_llama.LlamaRMSNorm, LigerRMSNorm)

JasonZhu1313 marked this conversation as resolved.
Show resolved Hide resolved
@require_lomo
@require_torch_gpu
def test_lomo(self):
Expand Down
Loading