Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Sequence-Level KD #2220

Merged
merged 8 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 2 additions & 51 deletions examples/scripts/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,6 @@
--lora_alpha 16
"""

import logging
import multiprocessing
import os
from contextlib import nullcontext
import copy

from trl.commands.cli_utils import DPOScriptArguments, init_zero_verbose, TrlParser, fix_chat_template_if_needed, \
is_right_apply_chat
from trl.env_utils import strtobool

TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0"))

if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"

from rich.console import Console
from rich.logging import RichHandler

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
Expand All @@ -79,9 +60,8 @@
get_peft_config,
get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)

if __name__ == "__main__":
parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
Expand Down Expand Up @@ -132,36 +112,7 @@
################
dataset = load_dataset(script_args.dataset_name)

# whether "apply chat template" is correct
is_right_chat = is_right_apply_chat(tokenizer, ds[args.dataset_train_split]["chosen"][0][:-1],
[ds[args.dataset_train_split]["chosen"][0][-1]])


def process(row):
prompt = row["chosen"][:-1]
chosen = [row["chosen"][-1]]
rejected = [row["rejected"][-1]]
if is_right_chat:
row["prompt"] = tokenizer.apply_chat_template(prompt, tokenize=False)
row["chosen"] = tokenizer.apply_chat_template(chosen, tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(rejected, tokenize=False)
else:
# fix if needed
fixed_prompt, fixed_chosen, fixed_rejected = fix_chat_template_if_needed(tokenizer, prompt, chosen, rejected)
row["prompt"] = fixed_prompt
row["chosen"] = fixed_chosen
row["rejected"] = fixed_rejected
return row

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=training_args.dataset_num_proc)

train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]

################
##########
# Training
################
trainer = DPOTrainer(
Expand Down
75 changes: 23 additions & 52 deletions trl/commands/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import importlib
import inspect
import logging
import os
import subprocess
import sys
from argparse import Namespace
from dataclasses import dataclass, field
from typing import List, Dict

import yaml
from transformers import HfArgumentParser


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -117,8 +118,8 @@ class DPOScriptArguments:
default=False,
metadata={
"help": "debug argument for distributed training;"
"fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
"fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
config: str = field(default=None, metadata={"help": "Path to the optional config file"})
Expand Down Expand Up @@ -283,54 +284,24 @@ def set_defaults_with_config(self, **kwargs):
action.required = False


def is_right_apply_chat(tokenizer, prompt: List[Dict[str, str]], assistant_content: List[Dict[str, str]]) -> bool:
"""
Checks if the assistant's content is correctly applied to the prompt in a chat template.

Args:
tokenizer: The tokenizer.
prompt: The initial prompt message.
assistant_content: The content provided by the assistant.

Returns:
bool: True if the assistant's content is correctly applied, False otherwise.
"""
def get_git_commit_hash(package_name):
try:
test_assistant = tokenizer.apply_chat_template(assistant_content, tokenize=False)
test_prompt = tokenizer.apply_chat_template(prompt, tokenize=False)
conversation = copy.deepcopy(prompt)
conversation.append(assistant_content[0])
if tokenizer.apply_chat_template(conversation) == test_prompt + test_assistant:
return True
# Import the package to locate its path
package = importlib.import_module(package_name)
# Get the path to the package using inspect
package_path = os.path.dirname(inspect.getfile(package))

# Navigate up to the Git repository root if the package is inside a subdirectory
git_repo_path = os.path.abspath(os.path.join(package_path, ".."))
git_dir = os.path.join(git_repo_path, ".git")

if os.path.isdir(git_dir):
# Run the git command to get the current commit hash
commit_hash = (
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=git_repo_path).strip().decode("utf-8")
)
return commit_hash
else:
return False
return None
except Exception as e:
return False


def fix_chat_template_if_needed(tokenizer, prompt: List[Dict[str, str]], chosen: List[Dict[str, str]], rejected: List[Dict[str, str]]):
"""
Fixes the chat template if needed.

Args:
tokenizer: The tokenizer.
prompt: The initial prompt message.
chosen: The chosen response, a list containing a single dictionary representing the chosen message.
rejected: The rejected response, a list containing a single dictionary representing the rejected message.

Returns:
- tuple: A tuple containing the fixed prompt, fixed chosen response, and fixed rejected response.
"""
conversation_chosen = copy.deepcopy(prompt)
conversation_rejected = copy.deepcopy(prompt)
conversation_chosen.append(chosen[0])
conversation_rejected.append(rejected[0])
conversation_chosen = tokenizer.apply_chat_template(conversation_chosen, tokenize=False)
conversation_rejected = tokenizer.apply_chat_template(conversation_rejected, tokenize=False)
# find position
start_position = conversation_chosen.find(chosen[0]['content'][0])
# The following is right
fixed_prompt = conversation_chosen[:start_position]
fixed_chosen = conversation_chosen[start_position:]
fixed_rejected = conversation_rejected[start_position:]
return fixed_prompt, fixed_chosen, fixed_rejected
return f"Error: {str(e)}"
10 changes: 10 additions & 0 deletions trl/trainer/gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional["PeftConfig"] = None,
formatting_func: Optional[Callable] = None,
seq_kd: bool = False
):
# add remove_unused_columns=False to the the dataclass args
args.remove_unused_columns = False
Expand Down Expand Up @@ -136,6 +137,7 @@ def __init__(
self.lmbda = args.lmbda
self.beta = args.beta
self.temperature = args.temperature
self.seq_kd = seq_kd

self.generation_config = GenerationConfig(
max_new_tokens=args.max_new_tokens,
Expand Down Expand Up @@ -280,6 +282,14 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
With probability `self.lmbda`, it generates new responses using the student model,
which are then used for training instead of the original inputs.
"""
if self.seq_kd:
with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
)
inputs["input_ids"] = new_input_ids
inputs["attention_mask"] = new_attention_mask
inputs["labels"] = new_labels
if random.random() <= self.lmbda:
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
Expand Down