
[QEff Finetune] : Made fixes to training script #439

Draft pull request: wants to merge 11 commits into base main.

14 changes: 10 additions & 4 deletions QEfficient/cloud/finetune.py
@@ -20,12 +20,14 @@
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.loss.loss_factory import get_loss
from QEfficient.finetune.utils.config_utils import (
generate_dataset_config,
generate_peft_config,
update_config,
)
from QEfficient.finetune.utils.dataset_utils import get_dataloader
from QEfficient.finetune.utils.helper import get_rank, is_rank_zero
from QEfficient.finetune.utils.parser import get_finetune_parser
from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
from QEfficient.utils._utils import login_and_download_hf_lm
@@ -67,7 +69,7 @@ def setup_distributed_training(train_config: TrainConfig) -> None:
dist_backend_map = {"cpu": "gloo", "qaic": "qccl", "cuda": "gloo"}
dist.init_process_group(backend=dist_backend_map[torch_device.type])
# from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
getattr(torch, torch_device.type).set_device(dist.get_rank())
getattr(torch, torch_device.type).set_device(get_rank())


def setup_seeds(seed: int) -> None:
@@ -114,6 +116,7 @@ def load_model_and_tokenizer(
attn_implementation="sdpa",
torch_dtype=torch.float16,
)
model.loss_function = get_loss(train_config.task_type)(dataset_config.num_labels)

if not hasattr(model, "base_model_prefix"):
raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
@@ -131,6 +134,7 @@ def load_model_and_tokenizer(
attn_implementation="sdpa",
torch_dtype=torch.float16,
)
model.loss_function = get_loss(train_config.task_type)()

tokenizer = AutoTokenizer.from_pretrained(
train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
@@ -192,7 +196,9 @@ def apply_peft(
else:
peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

if is_rank_zero():
model.print_trainable_parameters()

return model

@@ -290,7 +296,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
if train_config.enable_ddp:
model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
model = nn.parallel.DistributedDataParallel(model, device_ids=[get_rank()])
results = train(
model,
tokenizer,
@@ -299,7 +305,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
optimizer,
scheduler,
train_config,
dist.get_rank() if train_config.enable_ddp else None,
get_rank() if train_config.enable_ddp else None,
)
if train_config.enable_ddp:
dist.destroy_process_group()
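
For context, a minimal sketch of what the new `get_rank()` / `is_rank_zero()` helpers from `QEfficient.finetune.utils.helper` are assumed to do here; the actual implementations may differ, and this is only to show why replacing `dist.get_rank()` can make the script usable without an initialized process group.

```python
# Hypothetical sketch only; the real helpers live in
# QEfficient/finetune/utils/helper.py and may be implemented differently.
import torch.distributed as dist


def get_rank() -> int:
    # Fall back to rank 0 when torch.distributed is unavailable or not
    # initialized, so single-device runs work without a process group.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def is_rank_zero() -> bool:
    return get_rank() == 0
```
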
12 changes: 8 additions & 4 deletions QEfficient/finetune/data/sampler.py
@@ -49,14 +49,18 @@ def __init__(
) -> None:
random.seed(seed)
self.batch_sampler = LengthBasedBatchSampler(
data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
data_source, batch_size=batch_size, drop_last=False, shuffle=shuffle
)
self.num_replicas = num_replicas
self.rank = rank
assert len(self.batch_sampler) % self.num_replicas == 0, (
"Length of batch samples should be divisible by number to processes in DDP."
)
self.sampler_len = len(self.batch_sampler) // self.num_replicas
self.max_length = len(self.batch_sampler)

def __iter__(self):
max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
return islice(self.batch_sampler, self.rank, self.max_length, self.num_replicas)

def __len__(self):
return len(self.batch_sampler) // self.num_replicas
return self.sampler_len
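
To see why the new assertion plus the precomputed `sampler_len` keep every rank in step, here is a small standalone illustration of how `islice` shards batches across DDP ranks; the batch list is made up, not taken from the repo.

```python
# Standalone illustration (toy data, not repo code) of sharding batches
# across DDP ranks with islice, as the sampler's __iter__ does.
from itertools import islice

batches = [[0, 1], [2, 3], [4, 5], [6, 7]]  # e.g. output of LengthBasedBatchSampler
num_replicas = 2  # len(batches) % num_replicas == 0, as the new assert requires

for rank in range(num_replicas):
    # Each rank takes every num_replicas-th batch, starting at its own offset.
    shard = list(islice(batches, rank, len(batches), num_replicas))
    print(rank, shard)
# 0 [[0, 1], [4, 5]]
# 1 [[2, 3], [6, 7]]
```
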
4 changes: 2 additions & 2 deletions QEfficient/finetune/dataset/alpaca_dataset.py
@@ -11,6 +11,8 @@
import torch
from torch.utils.data import Dataset

from QEfficient.finetune.dataset.helper import IGNORE_INDEX

PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -42,8 +44,6 @@ def __len__(self):
return len(self.ann)

def __getitem__(self, index):
IGNORE_INDEX = -100 # The default setting

ann = self.ann[index]
if ann.get("input", "") == "":
prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
4 changes: 3 additions & 1 deletion QEfficient/finetune/dataset/grammar_dataset.py
@@ -10,6 +10,8 @@
from datasets import load_dataset
from torch.utils.data import Dataset

from QEfficient.finetune.dataset.helper import IGNORE_INDEX


class grammar(Dataset):
def __init__(self, tokenizer, csv_name=None, context_length=None):
@@ -58,7 +60,7 @@ def convert_to_features(self, example_batch):
sample = {
"input_ids": prompt_ids + label_ids,
"attention_mask": [1] * len(prompt_ids + label_ids),
"labels": [-100] * len(prompt_ids) + label_ids,
"labels": [IGNORE_INDEX] * len(prompt_ids) + label_ids,
}

return sample
13 changes: 7 additions & 6 deletions QEfficient/finetune/dataset/gsm8k_dataset.py
@@ -9,6 +9,8 @@

from datasets import Dataset, load_dataset

from QEfficient.finetune.dataset.helper import IGNORE_INDEX

default_instruction = """### Instruction: Solve the math question using a basic calculator.
Calculator can be invoked using the format: <<expression=answer>>.
"expression" can be one of the 4 arithmetic operations, and "answer" will be filled in for you.
@@ -26,9 +28,8 @@ def tokenize_and_mask(row: Dict[str, str], *, tokenizer, instruction) -> Dict[st

input_str = tokenizer.bos_token + instruction.format(**row)
ques_ids = tokenizer(input_str, add_special_tokens=False, return_attention_mask=False)["input_ids"]
ans_ids = tokenizer(row["answer"] + tokenizer.eos_token, add_special_tokens=False, return_attention_mask=False)[
"input_ids"
]
ans_str = row["answer"] + tokenizer.eos_token
ans_ids = tokenizer(ans_str, add_special_tokens=False, return_attention_mask=False)["input_ids"]
input_ids = ques_ids + ans_ids

# State machine to recognize <<expression=answer>> and mask answer
@@ -39,11 +40,11 @@ def tokenize_and_mask(row: Dict[str, str], *, tokenizer, instruction) -> Dict[st
elif mode == 1 and token in equal_tokens:
mode = 2
elif mode == 2:
ans_ids[i] = -100
ans_ids[i] = IGNORE_INDEX
if token in end_tokens:
mode = 0

labels = [-100] * len(ques_ids) + ans_ids
labels = [IGNORE_INDEX] * len(ques_ids) + ans_ids

inputs = {"input_ids": input_ids, "labels": labels}
return inputs
@@ -54,7 +55,7 @@ def pad_to_max_length(row: Dict[str, list], *, tokenizer, max_length: int) -> Di
return {
"input_ids": row["input_ids"] + [tokenizer.pad_token_id] * (max_length - length),
"attention_mask": [1] * length + [0] * (max_length - length),
"labels": row["labels"] + [-100] * (max_length - length),
"labels": row["labels"] + [IGNORE_INDEX] * (max_length - length),
}


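
A character-level toy of the same masking idea follows; the real code runs the state machine over token ids, and its start/equal/end token sets are not shown in this hunk, so they are treated as assumptions here.

```python
# Character-level toy of the <<expression=answer>> masking; the repo version
# operates on token ids, so this only illustrates the state-machine idea.
IGNORE_INDEX = -100

text = "There are <<2+3=5>> apples."
labels = list(text)
mode = 0
for i, ch in enumerate(text):
    if mode == 0 and text[i:i + 2] == "<<":
        mode = 1                      # inside the calculator expression
    elif mode == 1 and ch == "=":
        mode = 2                      # everything after '=' is the answer
    elif mode == 2:
        labels[i] = IGNORE_INDEX      # hide the calculator's answer from the loss
        if ch == ">":
            mode = 0
print(labels)
```
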
8 changes: 8 additions & 0 deletions QEfficient/finetune/dataset/helper.py
@@ -0,0 +1,8 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

IGNORE_INDEX = -100
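
Centralizing the constant is more than cosmetic: -100 is the default `ignore_index` of `torch.nn.functional.cross_entropy` (and of the Hugging Face loss utilities used below), so label positions set to IGNORE_INDEX simply do not contribute to the loss. A quick standalone check:

```python
# Standalone check that positions labeled -100 are skipped by cross entropy.
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100
logits = torch.randn(4, 10)                         # 4 token positions, vocab size 10
labels = torch.tensor([3, IGNORE_INDEX, 7, IGNORE_INDEX])
full = F.cross_entropy(logits, labels)              # ignore_index defaults to -100
kept = F.cross_entropy(logits[[0, 2]], labels[[0, 2]])
assert torch.allclose(full, kept)                   # only positions 0 and 2 count
```
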
11 changes: 9 additions & 2 deletions QEfficient/finetune/dataset/samsum_dataset.py
@@ -7,9 +7,11 @@

import datasets

from QEfficient.finetune.dataset.helper import IGNORE_INDEX


def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
dataset = datasets.load_dataset("Samsung/samsum", split=split, trust_remote_code=True)
dataset = datasets.load_dataset("knkarthick/samsum", split=split, trust_remote_code=True)

Contributor (Author) commented: Please check if this dataset can be used.

Contributor replied: We are not distributing this dataset; hence, it should not be a problem.


prompt = "Summarize this dialog:\n{dialog}\n---\nSummary:\n"

@@ -35,10 +37,15 @@ def tokenize_add_label(sample):
pad_to_max_length=True,
)

labels = [IGNORE_INDEX] * len(prompt) + summary
# sentence: <bos> <prompt> <summary> <eos> <pad>
# labels : -100 -100 <summary> <eos> <pad>
# Here, if pad token is not available then eos is used as pad token.

sample = {
"input_ids": prompt + summary,
"attention_mask": [1] * (len(prompt) + len(summary)),
"labels": [-100] * len(prompt) + summary,
"labels": labels,
}

return sample
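
A tiny worked example of the labels layout sketched in the comment above, with made-up token ids (not actual samsum tokenizer output):

```python
# Made-up ids, only to show the prompt/summary masking shape.
IGNORE_INDEX = -100
prompt = [1, 17, 18, 19]            # <bos> + "Summarize this dialog: ..." tokens
summary = [42, 43, 44, 2]           # summary tokens + <eos>

labels = [IGNORE_INDEX] * len(prompt) + summary
input_ids = prompt + summary
# input_ids: [1, 17, 18, 19, 42, 43, 44, 2]
# labels   : [-100, -100, -100, -100, 42, 43, 44, 2]
```
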
6 changes: 6 additions & 0 deletions QEfficient/finetune/loss/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
19 changes: 19 additions & 0 deletions QEfficient/finetune/loss/common.py
@@ -0,0 +1,19 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from abc import ABC, abstractmethod

import torch


class BaseLoss(ABC):
def __init__(self, **kwargs):
pass

@abstractmethod
def __call__(self, logits: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
pass
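
To illustrate the contract, here is a hypothetical subclass (not part of this PR) of the kind a future task could register with the loss factory below:

```python
# Hypothetical example only; ForTokenClassificationLoss is not added by this PR.
import torch
from torch.nn import CrossEntropyLoss

from QEfficient.finetune.loss.common import BaseLoss


class ForTokenClassificationLoss(BaseLoss):
    def __init__(self, num_labels: int):
        self.num_labels = num_labels

    def __call__(self, logits: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
        # Flatten (batch, seq, num_labels) logits against (batch, seq) labels.
        return CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
```
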
70 changes: 70 additions & 0 deletions QEfficient/finetune/loss/generation_loss.py
@@ -0,0 +1,70 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from typing import Optional

import torch
import torch.nn as nn
from transformers.loss.loss_utils import fixed_cross_entropy

from QEfficient.finetune.loss.common import BaseLoss

# Note: Below code is taken from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/loss/loss_utils.py#L45
# The original code is modified to take loss_weight into consideration.
# It applies a per-sample (boolean) weight to the loss of each item in the batch.
# This is helpful when we explicitly want to set the loss for a particular
# sample in the batch to zero, e.g. when the dataset has been padded.


class ForCausalLMLoss(BaseLoss):
def __init__(self):
pass

def __call__(
self,
logits,
labels,
vocab_size: int,
num_items_in_batch: Optional[torch.Tensor] = None,
loss_weight: Optional[torch.Tensor] = None,
ignore_index: int = -100,
shift_labels: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()

if shift_labels is None:
# Shift so that tokens < n predict n
labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.to(logits.device)

if loss_weight is None:
# Flatten the tokens
logits = logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
return loss
else:
total_loss = torch.tensor(0.0, device=logits.device)
bs = logits.shape[0]
for i in range(bs):
# Flatten the tokens
_logits = logits[i].view(-1, vocab_size)
_shift_labels = shift_labels[i].view(-1)
# Enable model parallelism
loss = fixed_cross_entropy(_logits, _shift_labels, ignore_index=ignore_index, **kwargs)
loss *= loss_weight[i]
total_loss += loss

if torch.sum(loss_weight) == 0:
return total_loss
else:
total_loss /= torch.sum(loss_weight)
return total_loss
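
A hedged usage sketch of the per-sample weighting path; the shapes and values are illustrative assumptions, not taken from the training loop:

```python
# Illustrative shapes only; loss_weight zeroes out the second (padded) sample.
import torch

from QEfficient.finetune.loss.generation_loss import ForCausalLMLoss

loss_fn = ForCausalLMLoss()
logits = torch.randn(2, 5, 32)                  # (batch, seq_len, vocab_size)
labels = torch.randint(0, 32, (2, 5))
loss_weight = torch.tensor([1.0, 0.0])

loss = loss_fn(logits, labels, vocab_size=32, loss_weight=loss_weight)
print(loss)
```
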
21 changes: 21 additions & 0 deletions QEfficient/finetune/loss/loss_factory.py
@@ -0,0 +1,21 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------


from QEfficient.finetune.loss.generation_loss import ForCausalLMLoss
from QEfficient.finetune.loss.seq_cls_loss import ForSequenceClassificationLoss

loss_fn_dict = {
"seq_classification": ForSequenceClassificationLoss,
"generation": ForCausalLMLoss,
}


def get_loss(task_name: str):
if task_name not in loss_fn_dict:
raise RuntimeError(f"No loss function registered for this task name: '{task_name}'.")
return loss_fn_dict[task_name]
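
Usage mirrors the wiring in finetune.py above; the task_type strings are the two keys registered here:

```python
# Mirrors finetune.py: the classification loss needs num_labels, generation does not.
from QEfficient.finetune.loss.loss_factory import get_loss

seq_cls_loss = get_loss("seq_classification")(num_labels=2)   # ForSequenceClassificationLoss
causal_lm_loss = get_loss("generation")()                     # ForCausalLMLoss

# Unregistered task names fail fast:
# get_loss("translation")  -> RuntimeError: No loss function registered ...
```
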
63 changes: 63 additions & 0 deletions QEfficient/finetune/loss/seq_cls_loss.py
@@ -0,0 +1,63 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from typing import Optional

import torch
from torch.nn import BCEWithLogitsLoss, MSELoss
from transformers.loss.loss_utils import fixed_cross_entropy

from QEfficient.finetune.loss.common import BaseLoss

# Note: Below code is taken from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/loss/loss_utils.py#L92
# The original code is modified to take loss_weight into consideration.
# It applies a per-sample (boolean) weight to the loss of each item in the batch.
# This is helpful when we explicitly want to set the loss for a particular
# sample in the batch to zero, e.g. when the dataset has been padded.


class ForSequenceClassificationLoss(BaseLoss):
def __init__(self, num_labels):
self.num_labels = num_labels

def __call__(
self, pooled_logits: torch.Tensor, labels: torch.Tensor, loss_weight: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
num_labels = self.num_labels
if num_labels == 1:
problem_type = "regression"
elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
problem_type = "single_label_classification"
else:
problem_type = "multi_label_classification"

labels = labels.to(pooled_logits.device)
if problem_type == "regression":
loss_fct = MSELoss()
if num_labels == 1:
return loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
return loss_fct(pooled_logits, labels)
if problem_type == "single_label_classification":
if loss_weight is None:
return fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
else:
total_loss = torch.tensor(0.0, device=pooled_logits.device)
bs = pooled_logits.shape[0]
for i in range(bs):
total_loss += loss_weight[i] * fixed_cross_entropy(
pooled_logits[i].view(-1, num_labels), labels[i].view(-1), **kwargs
)
if torch.sum(loss_weight) == 0:
return total_loss
else:
total_loss /= torch.sum(loss_weight)
return total_loss

if problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
return loss_fct(pooled_logits, labels)
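
A short sketch of how the three problem-type branches are selected; the shapes and label dtypes here are assumptions chosen to hit each branch:

```python
# Illustrative shapes; the dtype of labels decides the branch when num_labels > 1.
import torch

from QEfficient.finetune.loss.seq_cls_loss import ForSequenceClassificationLoss

loss_fn = ForSequenceClassificationLoss(num_labels=3)
pooled_logits = torch.randn(4, 3)                 # (batch, num_labels)

# Integer labels -> single_label_classification (cross entropy)
loss = loss_fn(pooled_logits, torch.tensor([0, 2, 1, 2]))

# Float multi-hot labels -> multi_label_classification (BCEWithLogitsLoss)
loss = loss_fn(pooled_logits, torch.tensor([[1.0, 0.0, 1.0]] * 4))

# num_labels == 1 -> regression (MSELoss)
reg_fn = ForSequenceClassificationLoss(num_labels=1)
loss = reg_fn(torch.randn(4, 1), torch.randn(4))
```
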