# enable fine tuning on HPU #552

**Open:** splotnikv wants to merge 1 commit into `instructlab:main` from `splotnikv:hpu_ft_pub`

# InstructLab Training on HPU

## HPU-specific changes
The following changes are required to enable training on HPU:

|GPU|HPU|
|---|---|
|`from accelerate import Accelerator`|`from optimum.habana.accelerate import GaudiAccelerator`|
|`from accelerate.utils import FullyShardedDataParallelPlugin`|`from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin`|
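
For illustration, the swap can be gated at import time (a minimal sketch, not part of the PR; the availability probe mirrors the `is_torch_hpu_available()` helper added later in this diff):

```Python
# Minimal sketch: pick HPU or GPU classes at import time. The probe below
# mirrors this PR's is_torch_hpu_available() helper.
try:
    import habana_frameworks.torch.core  # noqa: F401
    HPU_AVAILABLE = True
except ImportError:
    HPU_AVAILABLE = False

if HPU_AVAILABLE:
    from optimum.habana.accelerate import GaudiAccelerator as Accelerator
    from optimum.habana.accelerate.utils import (
        GaudiFullyShardedDataParallelPlugin as FullyShardedDataParallelPlugin,
    )
else:
    from accelerate import Accelerator
    from accelerate.utils import FullyShardedDataParallelPlugin
```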

It is also recommended to use the HPU-optimized versions of transformers:

```Python
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()
```
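
A note on usage (based on optimum-habana's documented behavior, not stated in the PR): `adapt_transformers_to_gaudi()` monkey-patches `transformers` with Gaudi-optimized implementations, so it should be called once at startup, before any model is instantiated.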

## Bucketing
The Multipack sampler implementation produces a wide range of batches with different sample lengths and numbers of samples. Each of these combinations triggers a graph recompilation, and recompilation takes time and slows down training. To reduce the number of recompilations, the HPU implementation uses a bucketing approach, in which the maximum sample length in a batch is aligned to a predefined value. It is similar to padding, except that the samples in the batch are padded not to the longest sample but to a slightly larger value.



To compute the bucket size, we use the following algorithm:
- First, find the most significant set bit (MSB) of the longest sample in the batch; call it S.
- Then slice the range [2 ** S, 2 ** (S+1)] into 16 buckets of the same size.
- Then use the top boundary of the smallest suitable bucket as the padding value.

This approach limits the bucketing overhead to at most 1/16 of the longest sample and allows us to significantly reduce the number of recompilations, as the sketch below illustrates.
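
Here is a minimal sketch of that computation (illustrative only; the PR's actual `simple_bucket()` implementation appears later in this diff):

```Python
def bucket_sketch(longest: int) -> int:
    """Pad `longest` up to the top boundary of its bucket (sketch only)."""
    s = longest.bit_length() - 1                # MSB position: 2**s <= longest < 2**(s+1)
    step = (1 << (s - 4)) if s >= 4 else 1      # 16 equal buckets across [2**s, 2**(s+1)]
    return (longest + step - 1) // step * step  # round up to the bucket's top boundary

# A longest sample of 1000 has its MSB at position 9, so it falls in [512, 1024],
# which is sliced into 16 buckets of width 32; 1000 is padded up to 1024.
print(bucket_sketch(1000))  # -> 1024
```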

## How to run
To run training, build a Docker image using the following Dockerfile:
```Dockerfile
FROM vault.habana.ai/gaudi-docker/1.21.0/rhel9.4/habanalabs/pytorch-installer-2.6.0:1.21.0-555

ARG CMAKE_ARGS="-DGGML_NATIVE=off"

WORKDIR /app
RUN pip install git+https://github.com/instructlab/instructlab.git@v0.26.1

WORKDIR /app
RUN pip install git+https://github.com/huggingface/optimum-habana.git@v1.18.0
```
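
For example (a sketch; the image tag and mount paths are placeholders, and the runtime flags follow Habana's standard Docker usage rather than anything stated in the PR):

```BASH
# Build the image from the Dockerfile above (the tag is a placeholder).
docker build -t instructlab-hpu .

# Start a shell with the Gaudi devices exposed via the Habana container runtime.
docker run -it --runtime=habana \
    -e HABANA_VISIBLE_DEVICES=all \
    --cap-add=sys_nice --ipc=host \
    -v "$(pwd)":/workspace -w /workspace \
    instructlab-hpu /bin/bash
```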

Then make the following changes to the config file:
```YAML
train:
  device: hpu
  distributed_backend: fsdp
  fsdp_cpu_offload_optimizer: false
  is_padding_free: true
  pipeline: accelerated
  disable_flash_attn: true
```
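
Two of these settings are worth calling out (my reading, not stated in the PR): `device: hpu` routes training to the Gaudi backend, and `disable_flash_attn: true` is presumably needed because the flash-attention kernels are CUDA-specific.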

And finally, run this command:
```BASH
ilab --config=./config.yaml model train --pipeline accelerated --data-path ./data.jsonl
```

The PR also adds a helper module containing the HPU availability check and the bucketing implementation described above:

```Python
import torch
from functools import lru_cache


@lru_cache(maxsize=None)
def is_torch_hpu_available() -> bool:
    try:
        import habana_frameworks.torch.core  # noqa: F401
    except ImportError:
        return False
    return True


def simple_bucket(length):
    """
    This bucketing algorithm relies only on the given number, rather than on
    slicing a known (min, max) range, for several reasons:
    1) Due to the use of the first-fit-decreasing (FFD) algorithm, the
       (min, max) sequence length of each rank will be much smaller than the
       (min, max) sequence length of the dataset. Bucketing on the
       (min, max) sequence length of the dataset is not practical.
    2) The (min, max) sequence length of a given rank is unknown until
       one epoch finishes, since the packing is done on the fly.
    3) Due to shuffling, the (min, max) sequence length of a given rank
       may vary between ranks. Once the (min, max) sequence length of a
       given rank changes, the bucketing also needs adjustment.

    This bucketing algorithm is based on the most significant set bit of the
    input number. It first checks which is the most significant set bit,
    assuming it is bit "S", and then slices the range [2 ** S, 2 ** (S+1)]
    into buckets of the same size. By default the range is divided into 16
    buckets, so the bucket size will be 2 ** (S - 4).
    For example, 0b10001 will be padded to 0b10010.
    This approach limits the overhead of bucketing (at most 1/16 of the input
    number) and also prevents recompilation due to a too-small bucket size.
    """
    # Count the bits of `length`, i.e. msb == length.bit_length().
    n = length
    msb = 0
    while n > 0:
        msb += 1
        n = n // 2

    # Alignment granularity: 2 ** (msb - 4), clamped to at least 1.
    align = (1 << (msb - 4)) if msb >= 4 else 1

    # Round `length` up to the next multiple of the alignment.
    return (length + align - 1) // align * align


def bucket(length):
    return simple_bucket(length)
```
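
As a quick sanity check (values computed by hand from the code above):

```Python
print(bucket(17))    # 0b10001 -> 0b10010, i.e. 18
print(bucket(1000))  # msb = 10, align = 64 -> padded to 1024
print(bucket(1024))  # already a multiple of the alignment -> 1024
```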

**Review comment:** Could you explain to me why we need the bucketing algorithm? I see in `main_ds.py` you are setting `lazy_mode=False`, which would mean we are using eager compilation, and afaik, eager mode in torch supports dynamically shaped tensors (which I am assuming is the case for habana torch too). I would really appreciate it if you can shed some light on this.

**Author reply:** Good question. Even in eager mode, we recompile the graph if shapes have changed. I'll find out more details and post them here.