[QEff Finetune] : Made fixes to training script #439
Draft: quic-mamta wants to merge 11 commits into quic:main from quic-meetkuma:jitender_fixes
+378 −108
Changes from all commits (11 commits, all authored by quic-meetkuma):
- 5ea7c14 Made fixes to training script based on recent findings.
- 423736d Cleaned up the patch and added padding of the dataset with a loss_wei…
- aa80625 Minor cleanup
- cb0b915 Updated loss arithmetic for gradient accumulation.
- 29b9339 Minor logging level change.
- 6685ff4 Fixed train and eval loss and ppl mismatch. Further cleanup added as …
- a0a03b4 Added custom loss implementation which will work based on loss_weight…
- 21eb82d Minor change to eval ppl calculation.
- 19697d0 Fixed formatting with ruff
- eee5328 Brought back pad_dataset function and added some documentation for th…
- 7f5f3b4 Updated results dict which is returned at the end of training.
@@ -0,0 +1,8 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

IGNORE_INDEX = -100
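For context (not part of this PR's diff), a minimal sketch of how a sentinel like IGNORE_INDEX is typically used when preparing labels: positions marked with -100 are skipped by PyTorch's cross-entropy loss, so padded tokens contribute nothing to training. The tensor names and values below are illustrative.

    import torch

    IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

    input_ids = torch.tensor([12, 57, 893, 0, 0])  # hypothetical token ids, 0 = pad
    labels = input_ids.clone()
    labels[input_ids == 0] = IGNORE_INDEX  # padded positions are excluded from the loss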
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
@@ -0,0 +1,19 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from abc import ABC, abstractmethod

import torch


class BaseLoss(ABC):
    def __init__(self, **kwargs):
        pass

    @abstractmethod
    def __call__(self, logits: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
        pass
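This abstract base (imported elsewhere in the PR as QEfficient.finetune.loss.common.BaseLoss) only fixes the calling convention: a loss is a class whose instances map (logits, labels) to a scalar tensor. A toy subclass, purely illustrative and not part of this PR, as a sketch of that contract:

    import torch

    from QEfficient.finetune.loss.common import BaseLoss


    class MeanLogitLoss(BaseLoss):
        # Hypothetical example only: a concrete loss just implements __call__ and
        # returns a scalar tensor, as ForCausalLMLoss and ForSequenceClassificationLoss do below.
        def __call__(self, logits: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
            return logits.float().mean()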
@@ -0,0 +1,70 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from typing import Optional

import torch
import torch.nn as nn
from transformers.loss.loss_utils import fixed_cross_entropy

from QEfficient.finetune.loss.common import BaseLoss

# Note: Below code is taken from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/loss/loss_utils.py#L45
# The original code is modified to take loss_weight into consideration.
# It will apply a boolean value to the loss for each item in the batch.
# This is helpful when we explicitly want to set loss for a particular
# sample in batch to zero. E.g. when padding of dataset is done.


class ForCausalLMLoss(BaseLoss):
    def __init__(self):
        pass

    def __call__(
        self,
        logits,
        labels,
        vocab_size: int,
        num_items_in_batch: Optional[torch.Tensor] = None,
        loss_weight: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        shift_labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Upcast to float if we need to compute the loss to avoid potential precision issues
        logits = logits.float()

        if shift_labels is None:
            # Shift so that tokens < n predict n
            labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
            shift_labels = labels[..., 1:].contiguous()
        shift_labels = shift_labels.to(logits.device)

        if loss_weight is None:
            # Flatten the tokens
            logits = logits.view(-1, vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
            return loss
        else:
            total_loss = torch.tensor(0.0, device=logits.device)
            bs = logits.shape[0]
            for i in range(bs):
                # Flatten the tokens
                _logits = logits[i].view(-1, vocab_size)
                _shift_labels = shift_labels[i].view(-1)
                # Enable model parallelism
                loss = fixed_cross_entropy(_logits, _shift_labels, ignore_index=ignore_index, **kwargs)
                loss *= loss_weight[i]
                total_loss += loss

            if torch.sum(loss_weight) == 0:
                return total_loss
            else:
                total_loss /= torch.sum(loss_weight)
                return total_loss
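A hedged usage sketch of the loss_weight path the note above describes: a per-sample weight of 0 drops a padding sample from the batch loss, and the total is normalized by the number of real samples. Shapes, sizes, and values are illustrative, and the example assumes a transformers version that ships transformers.loss.loss_utils.fixed_cross_entropy.

    import torch

    from QEfficient.finetune.loss.generation_loss import ForCausalLMLoss

    loss_fn = ForCausalLMLoss()
    bs, seq_len, vocab_size = 4, 8, 32  # illustrative sizes
    logits = torch.randn(bs, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (bs, seq_len))
    # Suppose the last sample was added only to pad the dataset to a full batch:
    # weight 0 zeroes its contribution, and the sum is divided by sum(loss_weight) == 3.
    loss_weight = torch.tensor([1.0, 1.0, 1.0, 0.0])
    loss = loss_fn(logits, labels, vocab_size=vocab_size, loss_weight=loss_weight)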
@@ -0,0 +1,21 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------


from QEfficient.finetune.loss.generation_loss import ForCausalLMLoss
from QEfficient.finetune.loss.seq_cls_loss import ForSequenceClassificationLoss

loss_fn_dict = {
    "seq_classification": ForSequenceClassificationLoss,
    "generation": ForCausalLMLoss,
}


def get_loss(task_name: str):
    if task_name not in loss_fn_dict:
        raise RuntimeError(f"No loss function registered for this task name: '{task_name}'.")
    return loss_fn_dict[task_name]
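get_loss returns the registered loss class (not an instance) for a task name, so callers instantiate it before use, and unknown task names raise RuntimeError. A usage sketch, assuming get_loss is in scope (the module path of this registry file is not visible in this diff view):

    generation_loss = get_loss("generation")()                    # ForCausalLMLoss instance
    seq_cls_loss = get_loss("seq_classification")(num_labels=2)   # ForSequenceClassificationLoss instance
    # get_loss("translation") would raise RuntimeError: no loss registered for that task name.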
@@ -0,0 +1,63 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from typing import Optional

import torch
from torch.nn import BCEWithLogitsLoss, MSELoss
from transformers.loss.loss_utils import fixed_cross_entropy

from QEfficient.finetune.loss.common import BaseLoss

# Note: Below code is taken from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/loss/loss_utils.py#L92
# The original code is modified to take loss_weight into consideration.
# It will apply a boolean value to the loss for each item in the batch.
# This is helpful when we explicitly want to set loss for a particular
# sample in batch to zero. E.g. when padding of dataset is done.


class ForSequenceClassificationLoss(BaseLoss):
    def __init__(self, num_labels):
        self.num_labels = num_labels

    def __call__(
        self, pooled_logits: torch.Tensor, labels: torch.Tensor, loss_weight: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        num_labels = self.num_labels
        if num_labels == 1:
            problem_type = "regression"
        elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
            problem_type = "single_label_classification"
        else:
            problem_type = "multi_label_classification"

        labels = labels.to(pooled_logits.device)
        if problem_type == "regression":
            loss_fct = MSELoss()
            if num_labels == 1:
                return loss_fct(pooled_logits.squeeze(), labels.squeeze())
            else:
                return loss_fct(pooled_logits, labels)
        if problem_type == "single_label_classification":
            if loss_weight is None:
                return fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
            else:
                total_loss = torch.tensor(0.0, device=pooled_logits.device)
                bs = pooled_logits.shape[0]
                for i in range(bs):
                    total_loss += loss_weight[i] * fixed_cross_entropy(
                        pooled_logits[i].view(-1, num_labels), labels[i].view(-1), **kwargs
                    )
                if torch.sum(loss_weight) == 0:
                    return total_loss
                else:
                    total_loss /= torch.sum(loss_weight)
                    return total_loss

        if problem_type == "multi_label_classification":
            loss_fct = BCEWithLogitsLoss()
            return loss_fct(pooled_logits, labels)
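A usage sketch of the single-label classification path with the loss_weight modification (values below are illustrative, not from the PR): integer labels with num_labels > 1 select single_label_classification, and a weight of 0 excludes a padded sample before the total is normalized by sum(loss_weight).

    import torch

    from QEfficient.finetune.loss.seq_cls_loss import ForSequenceClassificationLoss

    loss_fn = ForSequenceClassificationLoss(num_labels=2)
    pooled_logits = torch.randn(4, 2)                  # (batch, num_labels)
    labels = torch.tensor([0, 1, 1, 0])                # long dtype -> single_label_classification
    loss_weight = torch.tensor([1.0, 1.0, 0.0, 1.0])   # third sample excluded, e.g. dataset padding
    loss = loss_fn(pooled_logits, labels, loss_weight=loss_weight)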
Review comment: Please check if this dataset can be used.
Reply: We are not distributing this dataset; hence, it should not be a problem.