enable fine tuning on HPU #552

Open · wants to merge 1 commit into main

61 changes: 61 additions & 0 deletions docs/hpu.md
@@ -0,0 +1,61 @@
# InstructLab Training on HPU

## HPU-specific changes
The following changes are required to enable training on HPU:

|GPU|HPU|
|---|---|
|`from accelerate import Accelerator` | `from optimum.habana.accelerate import GaudiAccelerator`|
|`from accelerate.utils import FullyShardedDataParallelPlugin` | `from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin` |
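
For example, the accelerator class can be selected at runtime with a device-conditional import, mirroring what this PR does in `accelerator.py`. A minimal sketch (the helper name `pick_accelerator_classes` is illustrative, not part of the PR):

```Python
# Sketch: pick the Accelerator / FSDP-plugin classes for the target device.
# Assumes optimum-habana is installed when training on HPU.
def pick_accelerator_classes(device: str):
    if device == "hpu":
        from optimum.habana.accelerate import GaudiAccelerator as Accelerator
        from optimum.habana.accelerate.utils import (
            GaudiFullyShardedDataParallelPlugin as FullyShardedDataParallelPlugin,
        )
    else:
        from accelerate import Accelerator
        from accelerate.utils import FullyShardedDataParallelPlugin
    return Accelerator, FullyShardedDataParallelPlugin
```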

It is also recommended to use the HPU-optimized transformers model implementations:

```Python
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()
```

## Bucketing
The multipack sampler produces a wide range of batches with different sample lengths and numbers of samples. Each new combination triggers a graph recompilation, and recompilation takes time and slows down training. To reduce the number of recompilations, the HPU implementation uses a bucketing approach, in which the maximum sample length in a batch is aligned to a predefined value. This is similar to padding, except that the samples in a batch are padded not to the longest sample but to a slightly larger, bucket-aligned value.

![bucketing vs. padding](./hpu_pic/bucketing_vs_padding.png)


To compute the bucket size, we use the following algorithm:
- First, we find the most significant set bit (MSB) of the longest sample in the batch; call it S.
- Then we slice the range [2 ** S, 2 ** (S+1)] into 16 buckets of equal size.
- Then we use the upper boundary of the smallest suitable bucket as the padding value.

This approach limits the bucketing overhead to about 1/16th of the longest sample length and allows us to significantly reduce the number of recompilations.
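
A minimal sketch of this computation, mirroring the behaviour of the `simple_bucket` helper this PR adds in `src/instructlab/training/hpu_utils.py` (the function name below is illustrative):

```Python
def bucket_pad_length(length: int) -> int:
    # Number of bits needed to represent `length` (the MSB loop in simple_bucket computes the same value).
    msb = length.bit_length()
    # Bucket width: 2 ** (msb - 4), i.e. (2 ** msb) / 16.
    align = (1 << (msb - 4)) if msb >= 4 else 1
    # Round up to the top boundary of the smallest suitable bucket.
    return (length + align - 1) // align * align

print(bucket_pad_length(0b10001))  # 18: a longest sample of 17 tokens is padded to 18
```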

## How to run
To run training, build a Docker image using the following Dockerfile:
```Dockerfile
FROM vault.habana.ai/gaudi-docker/1.21.0/rhel9.4/habanalabs/pytorch-installer-2.6.0:1.21.0-555

ARG CMAKE_ARGS="-DGGML_NATIVE=off"

WORKDIR /app
RUN pip install git+https://github.com/instructlab/instructlab.git@v0.26.1

WORKDIR /app
RUN pip install git+https://github.com/huggingface/optimum-habana.git@v1.18.0
```

Then make the following changes to the config file:
```YAML
train:
device: hpu
distributed_backend: fsdp
fsdp_cpu_offload_optimizer: false
is_padding_free: true
pipeline: accelerated
disable_flash_attn: true
```

Finally, run the training command:
```BASH
ilab --config=./config.yaml model train --pipeline accelerated --data-path ./data.jsonl
```


Binary file added docs/hpu_pic/bucketing_vs_padding.png
13 changes: 12 additions & 1 deletion src/instructlab/training/accelerator.py
Expand Up @@ -3,7 +3,6 @@
from typing import Callable, Optional

# Third Party
from accelerate import Accelerator as TransformersAccel
from torch.utils.data import DataLoader
from transformers import get_scheduler
import torch
Expand Down Expand Up @@ -32,6 +31,7 @@ def __init__(
deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool] = False,
deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None,
fsdp_cpu_offload_params: Optional[bool] = False,
device: Optional[str] = None,
):
self.samples_per_gpu = samples_per_gpu
self.save_samples = save_samples
Expand All @@ -48,6 +48,7 @@ def __init__(
deepspeed_cpu_offload_optimizer_ratio
)
self.fsdp_cpu_offload_params = fsdp_cpu_offload_params
self.device_str = device

if self.distributed_framework == DistributedBackend.DEEPSPEED:
# Standard
Expand All @@ -69,6 +70,12 @@ def __init__(
"fsdp_plugin": self.get_fsdp_config(),
"mixed_precision": "bf16",
}

if device == "hpu":
from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel
else:
from accelerate import Accelerator as TransformersAccel

self.accelerator = TransformersAccel(
**accel_args,
)
Expand Down Expand Up @@ -160,6 +167,10 @@ def get_fsdp_config(self):
cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
)

if self.device_str == "hpu":
fsdp_plugin.use_orig_params = True
fsdp_plugin.sync_module_states = True

# `use_orig_params` must be disabled when using LoRA and FSDP together
# Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
if self.model.lora_config is not None:
Expand Down
2 changes: 2 additions & 0 deletions src/instructlab/training/config.py
Expand Up @@ -246,3 +246,5 @@ class TrainingArgs(BaseModel):
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
default="INFO"
)

device: Optional[str] = None
49 changes: 49 additions & 0 deletions src/instructlab/training/hpu_utils.py
@@ -0,0 +1,49 @@
import torch
from functools import lru_cache


@lru_cache(maxsize=None)
def is_torch_hpu_available() -> bool:
try:
import habana_frameworks.torch.core # noqa: F401
except ImportError:
return False
return True


def simple_bucket(length):
"""
Review comment (Collaborator):
Could you explain why we need the bucketing algorithm? I see that in main_ds.py you are setting lazy_mode=False, which would mean we are using eager compilation, and as far as I know, eager mode in torch supports dynamically shaped tensors (which I am assuming is the case for Habana torch too). I would really appreciate it if you could shed some light on this.

Reply (Author):
Good question. Even in eager mode we recompile the graph if the shapes have changed. I'll find out more details and post them here.

This bucket algorithm relies only on the given number, rather than on slicing
the known (min, max) range, for several reasons:
1) Due to the use of the first-fit-decreasing (FFD) algorithm, the
(min, max) sequence length of each rank will be much smaller than the
(min, max) sequence length of the dataset, so bucketing on the
(min, max) sequence length of the dataset is not practical.
2) The (min, max) sequence length of a given rank is unknown until one
epoch has finished, since the packing is done on the fly.
3) Due to the shuffling, the (min, max) sequence length of a given rank
may vary between ranks. Once the (min, max) sequence length of a given
rank changes, the bucketing also needs adjustment.

This bucket algorithm is based on the most significant set bit of the input
number. It first checks which bit is the most significant set bit, say bit
"S", and then slices the range [2 ** S, 2 ** (S+1)] into buckets of equal
size. By default the range is divided into 16 buckets, so the bucket size is
2 ** (S - 4). For example, 0b10001 will be padded to 0b10010.
This approach limits the overhead of bucketing (to at most 1/16 of the input
number) and also prevents recompilation due to a too-small bucket size.
"""
l = length
msb = 0
while l > 0:
msb += 1
l = l // 2

align = (1 << (msb - 4)) if msb >= 4 else 1

return (length + align - 1) // align * align


def bucket(length):
return simple_bucket(length)
75 changes: 62 additions & 13 deletions src/instructlab/training/main_ds.py
Expand Up @@ -33,6 +33,14 @@
UserWarning,
)

from instructlab.training.hpu_utils import is_torch_hpu_available

if is_torch_hpu_available():
import habana_frameworks.torch.core as htcore
import habana_frameworks.torch.distributed.hccl
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()

# Third Party
from tqdm import tqdm
from transformers import AutoConfig
Expand Down Expand Up @@ -122,7 +130,7 @@ def train(
if local_rank == 0:
inner_pb = tqdm(range(num_epoch_steps), desc=f"Epoch {epoch}")

# blast through the batches in the train loader up to the last step within the epoch.
# blast through the batches in the train loader up to the last step within the epoch.
for batch in accelerator.train_loader:
if global_step <= args.last_step:
# in the case of resuming, last_step > 0
Expand All @@ -137,10 +145,19 @@ def train(
micro_batch_size = float(torch.tensor([batch.pop("num_samples")]))
total_length = float(torch.tensor([batch.pop("total_length")]))
for k in batch:
batch[k] = batch[k].to(local_rank)
batch[k] = batch[k].to("hpu" if args.device == "hpu" else local_rank)

hpu_args = {}
if args.device == "hpu":
hpu_args = {
"use_flash_attention":True,
"lazy_mode":False,
}

output = model(
**batch,
use_cache=False,
**hpu_args,
)
loss = output.loss
log_loss = loss.detach().item()
Expand Down Expand Up @@ -177,8 +194,14 @@ def train(
elapsed_time = time.time() - start
overall_throughput = args.samples_per_gpu * world_size / elapsed_time
current_lr = accelerator.lr_scheduler.get_last_lr()[0]
cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]

if args.device == "hpu":
mem_allocated = torch.hpu.memory_allocated() / (1024**3)
malloc_retries = 0
else:
mem_allocated = torch.cuda.memory_allocated() / (1024**3)
malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]

global_grad_norm = (
model.get_global_grad_norm()
if hasattr(model, "get_global_grad_norm")
Expand All @@ -200,8 +223,8 @@ def train(
"rank": torch.distributed.get_rank(),
"overall_throughput": overall_throughput,
"lr": current_lr,
"cuda_mem_allocated": cuda_mem_allocated,
"cuda_malloc_retries": cuda_malloc_retries,
("hpu" if args.device == "hpu" else "cuda") + "_mem_allocated": mem_allocated,
("hpu" if args.device == "hpu" else "cuda") + "_malloc_retries": malloc_retries,
"num_loss_counted_tokens": int(num_loss_counted_tokens),
"num_tokens_rank0": int(total_length),
"batch_size": int(micro_batch_size),
Expand Down Expand Up @@ -234,7 +257,10 @@ def train(
global_step += 1
if local_rank == 0:
inner_pb.update(1)
torch.cuda.empty_cache()

if args.device != "hpu":
torch.cuda.empty_cache()

if args.checkpoint_at_epoch:
base_logger.debug(f"Saving checkpoint at epoch {epoch}")
save_checkpoint(
Expand Down Expand Up @@ -312,17 +338,24 @@ def main(args):
args.model_type = model_conf.model_type

#### distributed init #####
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
if args.device == "hpu":
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
else:
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

args.local_rank = int(os.environ["LOCAL_RANK"])

timeout = _get_collective_timeout()
if timeout is not None:
torch.distributed.init_process_group(timeout=timeout)
else:
torch.distributed.init_process_group()
backend = "hccl" if args.device == "hpu" else None
torch.distributed.init_process_group(backend=backend, timeout=timeout)

args.global_rank = torch.distributed.get_rank()
tensor = torch.ByteTensor([False]).cuda()

if args.device == "hpu":
tensor = torch.ByteTensor([False]).to('hpu')
else:
tensor = torch.ByteTensor([False]).cuda()

torch.distributed.all_reduce(tensor)
torch.distributed.barrier()

Expand Down Expand Up @@ -369,6 +402,7 @@ def main(args):
flash_enabled=flash_enabled,
noise_alpha=args.NEFTune_alpha,
lora_quant_bits=args.lora_quant_bits,
device=args.device,
)

args.base_model_args = m.base_model_args
Expand Down Expand Up @@ -407,6 +441,7 @@ def main(args):
samples_per_gpu=args.samples_per_gpu,
sampler=args.sampler,
seed=args.seed,
device=args.device,
)
if len(train_loader) == 0:
# this happens sometimes when we have more GPUs than data to process. In this case
Expand All @@ -426,6 +461,7 @@ def main(args):
samples_per_gpu=args.samples_per_gpu,
sampler=args.sampler,
seed=args.seed,
device=args.device,
)

if args.local_rank == 0:
Expand Down Expand Up @@ -457,6 +493,7 @@ def main(args):
deepspeed_cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio,
fsdp_cpu_offload_params=args.cpu_offload_params_fsdp,
save_samples=args.save_samples,
device=args.device,
)
# optimizer needs model that has been prepared by accelerator
# and then accelerator needs to be prepared AGAIN once optimizer is initialized
Expand Down Expand Up @@ -636,6 +673,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
if train_args.keep_last_checkpoint_only:
command.append("--keep_last_checkpoint_only")

if train_args.device:
    command.append(f"--device={train_args.device}")

logger.info("Running training command as subprocess: %s", " ".join(command))
process = None
interrupt: KeyboardInterrupt | Exception | None = None
Expand Down Expand Up @@ -837,6 +878,14 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
action="store_true",
help="Use Liger kernels for training.",
)

parser.add_argument(
"--device",
type=str,
default=None,
help="PyTorch device to use.",
)

args = parser.parse_args()
set_random_seed(args.seed)
main(args)
Expand Down