
[feat][merge] Support one-behind to reduce bubble time. Add profiling code. #6355

Open
wants to merge 1 commit into base: grpo-latest-ascend
6 changes: 6 additions & 0 deletions .gitignore
@@ -167,3 +167,9 @@ applications/ColossalChat/wandb
applications/ColossalChat/model
applications/ColossalChat/eval
applications/ColossalChat/rollouts
applications/ColossalChat/*.txt
applications/ColossalChat/*.db
applications/ColossalChat/stdin
applications/ColossalChat/*.zip
applications/ColossalChat/*.prof
applications/ColossalChat/*.png
35 changes: 27 additions & 8 deletions applications/ColossalChat/coati/dataset/loader.py
@@ -367,9 +367,9 @@ def apply_chat_template_and_mask(
}

# Format for RL.
gt_answer = None
if "messages" in chat and "gt_answer" in chat:
gt_answer = chat["gt_answer"]
if "messages" in chat:
gt_answer = chat.get("gt_answer", None)
test_cases = chat.get("test_cases", None)
chat = [chat["messages"]]

tokens = []
@@ -402,12 +402,14 @@ def apply_chat_template_and_mask(
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx

if gt_answer is not None:
gt_answer = tokenizer.encode(
gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt"
)
gt_answer = gt_answer.squeeze(1)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}

elif test_cases is not None:
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"test_cases": test_cases,
}
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
@@ -440,3 +442,20 @@ def __getitem__(self, index: int):
tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
self.tokenized_texts[index] = dict(tokens)
return self.tokenized_texts[index]


def collate_fn_grpo(batch):
input_ids = [item["input_ids"] for item in batch]
attention_mask = [item["attention_mask"] for item in batch]
labels = [item["labels"] for item in batch]
# Assume input_ids, attention_mask, labels are already of the same length,
# otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
input_ids = torch.stack(input_ids)
attention_mask = torch.stack(attention_mask)
labels = torch.stack(labels)
ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
if "test_cases" in batch[0]:
ret["test_cases"] = [item["test_cases"] for item in batch]
if "gt_answer" in batch[0]:
ret["gt_answer"] = [item["gt_answer"] for item in batch]
return ret
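
The comment inside collate_fn_grpo notes that ragged samples would need pad_sequence instead of torch.stack. A minimal sketch of that padded variant, assuming the caller passes the tokenizer's pad token id (the helper name and its parameters are illustrative and not part of this PR):

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_fn_grpo_padded(batch, pad_token_id: int, ignore_idx: int = -100):
    # Pad ragged samples to the longest sequence in the batch instead of torch.stack.
    ret = {
        "input_ids": pad_sequence(
            [item["input_ids"] for item in batch], batch_first=True, padding_value=pad_token_id
        ),
        "attention_mask": pad_sequence(
            [item["attention_mask"] for item in batch], batch_first=True, padding_value=0
        ),
        "labels": pad_sequence(
            [item["labels"] for item in batch], batch_first=True, padding_value=ignore_idx
        ),
    }
    # Non-tensor fields are passed through as plain lists, mirroring collate_fn_grpo.
    if "test_cases" in batch[0]:
        ret["test_cases"] = [item["test_cases"] for item in batch]
    if "gt_answer" in batch[0]:
        ret["gt_answer"] = [item["gt_answer"] for item in batch]
    return ret
```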
168 changes: 118 additions & 50 deletions applications/ColossalChat/coati/distributed/consumer.py

Large diffs are not rendered by default.
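
The one-behind scheduling named in the PR title is implemented mainly in consumer.py, whose diff is not rendered here. Purely as an illustration of the idea, with hypothetical producer.generate and consumer.train stand-ins rather than this repository's API: with n_behind > 0 the consumer trains on a rollout batch produced n_behind steps earlier, so generation of newer batches overlaps with training and the pipeline bubble shrinks.

```python
from collections import deque


def run_pipeline(producer, consumer, num_steps: int, n_behind: int = 1):
    # Illustrative sketch only: keep the consumer n_behind steps behind the
    # newest rollouts so generation and training overlap.
    pending = deque()
    for step in range(num_steps):
        pending.append(producer.generate(step))  # asynchronous in the real system
        if len(pending) > n_behind:
            consumer.train(pending.popleft())    # train on an older rollout batch
    while pending:                               # drain the queue after the last step
        consumer.train(pending.popleft())
```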

81 changes: 17 additions & 64 deletions applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -1,14 +1,12 @@
from contextlib import nullcontext
from typing import Any, Dict, Optional
from typing import Any, Optional

import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.distributed.utils import memory_efficient_logprob
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -40,6 +38,8 @@ def __init__(
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
enable_profiling: bool = False,
n_behind: int = 0,
):
print(f"Using GRPO config: {grpo_config}")
if (
@@ -65,6 +65,8 @@ def __init__(
minibatch_size,
save_interval=save_interval,
save_dir=save_dir,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -119,20 +121,7 @@ def __init__(
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward.
response_format_tags = grpo_config.get("response_format_tags", None)
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
}
self.reward_model = VerifiableReward(
reward_fns=[
math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
],
tokenizer=self.tokenizer,
tags=response_format_tags,
**reward_model_kwargs,
)
grpo_config.get("response_format_tags", None)
self.global_step = 0

self.lr_scheduler = CosineAnnealingWarmupLR(
@@ -295,12 +284,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
)

if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
reference_action_log_probs = memory_efficient_logprob(
reference_model_outputs["outputs"]["logits"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)
else:
# Dummy reference logprobs for data iterator.
@@ -323,11 +311,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

def _criterion(outputs, inputs):
action_logits = outputs.logits
action_log_probs = calc_action_log_probs(
action_logits / self.generate_config["temperature"],
action_log_probs = memory_efficient_logprob(
action_logits,
inputs["input_ids"],
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)
if "reference_action_log_probs" in inputs:
per_token_kl = (
@@ -370,16 +358,15 @@ def _criterion(outputs, inputs):
mean_kl.append(kl)
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
else:

policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
action_log_probs = calc_action_log_probs(
action_log_probs = memory_efficient_logprob(
policy_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)

if self.policy_loss_fn.beta > 0:
@@ -388,11 +375,11 @@ def _criterion(outputs, inputs):
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
reference_action_log_probs = calc_action_log_probs(
reference_action_log_probs = memory_efficient_logprob(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
@@ -498,40 +485,6 @@ def _criterion(outputs, inputs):
else:
return None

def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
"""
Calculate the group reward for the given rollout group.

Args:
rollout_group (Dict[str, Any]):
a group of samples generated by the model from the same prompt
contain the following keys:
"input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
"attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
"action_mask": torch.Tensor, [num_of_generation, response_length]
"action_log_probs": torch.Tensor, [num_of_generation, response_length]
"response_idx": int, torch.Tensor, [num_of_generation, 2]
"gt_answer": torch.Tensor, [num_of_generation, 128]
"temperature": torch.Tensor, [] (scalar)

Returns:
Dict[str, Any]: The new group data with calculated reward.
"""
reward_model_output = self.reward_model(
rollout["input_ids"],
gt_answer=rollout["gt_answer"],
response_idx=rollout["response_idx"],
)
# [num_of_generation]
reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)

rollout["reward"] = reward.view((-1, 1))
rollout["format_acc"] = format_acc.view((-1, 1))
rollout["ans_acc"] = ans_acc.view((-1, 1))
return rollout

def state_dict(self):
self.policy_model._force_wait_all_gather()
model = self.policy_model.unwrap()
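
This file's diff swaps calc_action_log_probs for memory_efficient_logprob; the helper itself lives in coati/distributed/utils.py and is not shown here. A rough sketch of the general technique, chunking the log-softmax over the response tokens so a full vocabulary-sized log-probability tensor is never materialized at once (the function below is an assumption for illustration, not the library code, and it ignores the shard_config handling the real helper accepts):

```python
import torch


def chunked_action_logprob(logits, input_ids, num_action, chunk_size=512):
    # logits: [B, S, V] raw model outputs; input_ids: [B, S]; the last
    # num_action tokens of input_ids are the response ("action") tokens.
    action_logits = logits[:, -num_action - 1 : -1, :]  # positions that predict the actions
    action_targets = input_ids[:, -num_action:]
    chunks = []
    for start in range(0, num_action, chunk_size):
        logit_chunk = action_logits[:, start : start + chunk_size, :].float()
        target_chunk = action_targets[:, start : start + chunk_size]
        log_probs = torch.log_softmax(logit_chunk, dim=-1)
        chunks.append(log_probs.gather(-1, target_chunk.unsqueeze(-1)).squeeze(-1))
    return torch.cat(chunks, dim=1)  # [B, num_action]
```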
@@ -74,9 +74,8 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
micro_batch_size = input_ids.size(0)
input_ids = input_ids.to(get_current_device())
attention_mask = attention_mask.to(get_current_device())
gt_answer = None
if "gt_answer" in kwargs:
gt_answer = kwargs.pop("gt_answer")
gt_answer = kwargs.pop("gt_answer", None)
test_cases = kwargs.pop("test_cases", None)
if self.num_generations > 1:
input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
@@ -116,8 +115,9 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}

if gt_answer is not None:
# repeat gt_answer for each prompt.
data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
data["gt_answer"] = gt_answer
if test_cases is not None:
data["test_cases"] = test_cases
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data

@@ -270,11 +270,11 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
}

data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}

if "gt_answer" in kwargs:
# repeat gt_answer for each prompt.
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
data = {k: v.to(get_current_device()) for k, v in data.items()}
if "gt_answer" in kwargs:
data["gt_answer"] = kwargs["gt_answer"]
if "test_cases" in kwargs:
data["test_cases"] = kwargs["test_cases"]
return data

def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
19 changes: 13 additions & 6 deletions applications/ColossalChat/coati/distributed/launch.py
@@ -16,7 +16,7 @@ def get_jsonl_size_fast(path: str) -> int:
with open(path) as f:
lines = f.readlines()
lines = [line for line in lines if line.strip()]
return len(lines) - 1
return len(lines)


def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
@@ -36,7 +36,6 @@ def launch_distributed(
train_batch_size: int,
train_minibatch_size: int,
train_dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
generate_config: Dict[str, Any],
train_model_config: Dict[str, Any],
@@ -57,6 +56,8 @@ def launch_distributed(
eval_generation_config: Optional[Dict[str, Any]] = None,
log_rollout_interval: int = 20,
rollout_save_dir: str = "./rollout",
enable_profiling: bool = False,
n_behind: int = 0,
):
if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -79,6 +80,11 @@ def launch_distributed(
f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
)

# Attention: Ray uses a complex scheduling method that considers various factors, including load balancing.
# When requesting resources, it is not guaranteed that the resource comes from a node with a lower node id;
# this goes against the design principle of our implementation, and we need to manually force the scheduling,
# allocating the producer to nodes with lower node ids and the consumer to the resources from nodes with higher
# node ids. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
nodes = ray.nodes()
node_info = {
node["NodeID"]: {
@@ -104,7 +110,6 @@ def launch_distributed(
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")

producer = SimpleProducer.options(
# num_cpus=1,
# num_cpus=num_proc_per_producer,
@@ -121,7 +126,6 @@ def launch_distributed(
num_episodes=num_episodes,
batch_size=inference_batch_size,
train_dataset_config=train_dataset_config,
dataloaders_config=dataloaders_config,
model_config=inference_model_config,
generate_config=generate_config,
tokenizer_config=tokenizer_config,
@@ -131,15 +135,16 @@ def launch_distributed(
consumer_plugin_config=plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
evaluation_function_type=grpo_config["reward_fn_type"],
response_format_tags=grpo_config["response_format_tags"],
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
eval_generation_config=eval_generation_config,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
producer_procs.append(producer)
ray.get([p.setup.remote() for p in producer_procs])
@@ -185,6 +190,8 @@ def launch_distributed(
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
consumer_procs.append(consumer)
ray.get([p.setup.remote() for p in consumer_procs])
Expand Down
Loading