64 changes: 31 additions & 33 deletions applications/Chat/coati/ray/example/1mmt_dummy.py
@@ -1,15 +1,18 @@
import argparse
import os
import socket
from copy import deepcopy
from functools import partial

import ray
import torch
from coati.models.base import RewardModel
from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
from coati.ray.src.utils import get_actor_from_args, get_critic_from_args, get_reward_model_from_args
from coati.ray.src.utils import (
get_actor_from_args,
get_critic_from_args,
get_reward_model_from_args,
get_strategy_from_args,
)
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
@@ -81,39 +84,44 @@ def main(args):
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

def trainer_model_fn():
actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
critic = get_critic_from_args(args.model, args.pretrain).half().cuda()
return actor, critic

# configure Trainer
trainer_refs = [
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=["maker1"],
strategy=args.trainer_strategy,
model=args.model,
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
model_fn=trainer_model_fn,
env_info=env_info_trainer,
pretrained=args.pretrain,
lora_rank=args.lora_rank,
train_batch_size=args.train_batch_size,
buffer_limit=16,
experience_batch_size=args.experience_batch_size,
max_epochs=args.max_epochs,
# kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
eval_performance=True,
debug=args.debug,
) for i, env_info_trainer in enumerate(env_info_trainers)
]

def model_fn():
actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
critic = get_critic_from_args(args.model, args.pretrain).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.pretrain).half().cuda()
initial_model = get_actor_from_args(args.model, args.pretrain).half().cuda()
return actor, critic, reward_model, initial_model

# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
strategy=args.maker_strategy,
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
experience_batch_size=args.experience_batch_size,
kl_coef=0.1,
# kwargs:
debug=args.debug,
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
@@ -122,32 +130,22 @@ def main(args):
eos_token_id=tokenizer.eos_token_id,
eval_performance=True,
use_cache=True,
debug=args.debug,
)

def init_inference_model(fn, model_name, pretrained):
model = fn(model_name, pretrained)
return model.half().cuda()

# init maker locally
ray.get(
experience_holder_ref.initialize_experience_maker_local.remote(
initial_model_func=partial(init_inference_model, get_actor_from_args, args.model, args.pretrain),
reward_model_func=partial(init_inference_model, get_reward_model_from_args, args.model, args.pretrain),
actor_func=partial(init_inference_model, get_actor_from_args, args.model, args.pretrain),
critic_func=partial(init_inference_model, get_critic_from_args, args.model, args.pretrain),
))

# configure sampler
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400))

def tokenize_fn(texts):
# print(texts)
input_ids = torch.stack(texts).cuda()
# print(input_ids.shape)
attn_mask = torch.ones_like(input_ids)
return {'input_ids': input_ids, 'attention_mask': attn_mask}

# uncomment this function if sync_models_from_trainers is True
# ray.get([
# trainer_ref.sync_models_to_remote_makers.remote()
# for trainer_ref in trainer_refs
# ])

wait_tasks = []

for trainer_ref in trainer_refs:
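The example now wires the trainers and the experience maker through factory callables (`strategy_fn`, `model_fn`) instead of model/strategy name strings, so each Ray actor builds its own models inside its strategy's init context. A minimal sketch of supplying a custom factory in the same spirit; `custom_trainer_model_fn`, the `'gpt2'` model name, and the strategy name passed to `partial` are illustrative choices, not part of this diff:

```python
from functools import partial

from coati.ray.src.utils import get_actor_from_args, get_critic_from_args, get_strategy_from_args


def custom_trainer_model_fn():
    # Runs inside the trainer actor, so the .cuda() calls land on that worker's GPU.
    actor = get_actor_from_args('gpt2', None).half().cuda()
    critic = get_critic_from_args('gpt2', None).half().cuda()
    return actor, critic


# Illustrative wiring; mirrors how 1mmt_dummy.py passes the factories to DetachedPPOTrainer.
trainer_kwargs = dict(
    strategy_fn=partial(get_strategy_from_args, 'colossalai_gemini'),
    model_fn=custom_trainer_model_fn,
)
```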
12 changes: 5 additions & 7 deletions applications/Chat/coati/ray/src/detached_trainer_base.py
@@ -21,7 +21,6 @@ class DetachedTrainer(ABC):
Args:
detached_strategy (DetachedStrategy): the strategy to use for training
detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
max_epochs (int, defaults to 1): the number of epochs of training process
data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
@@ -34,21 +33,17 @@ def __init__(self,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
experience_batch_size: int = 8,
max_epochs: int = 1,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
debug: bool = False,
**generate_kwargs) -> None:
debug: bool = False) -> None:
super().__init__()
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size,
limit=buffer_limit,
cpu_offload=buffer_cpu_offload)
self.experience_batch_size = experience_batch_size
self.max_epochs = max_epochs
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks
self.generate_kwargs = generate_kwargs
self.target_holder_name_list = experience_maker_holder_name_list
self.target_holder_list = []

@@ -61,9 +56,12 @@ def update_target_holder_list(self, experience_maker_holder_name_list):
self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))

@abstractmethod
def _update_remote_makers(self):
def _update_remote_makers(self, fully_update: bool = False, **kwargs):
pass

def sync_models_to_remote_makers(self, **kwargs):
self._update_remote_makers(fully_update=True, **kwargs)

@abstractmethod
def training_step(self, experience: Experience) -> Dict[str, Any]:
pass
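With the generation kwargs and `experience_batch_size` moved out of the base trainer, `sync_models_to_remote_makers` is simply `_update_remote_makers(fully_update=True)`. A rough sketch of the contract a subclass would satisfy, assuming `training_step` and `_update_remote_makers` are the abstract hooks visible in this file; the shard-sending body is illustrative, not the real `DetachedPPOTrainer` code:

```python
from typing import Any, Dict

from coati.experience_maker import Experience
from coati.ray.src.detached_trainer_base import DetachedTrainer


class MySimpleTrainer(DetachedTrainer):

    def _update_remote_makers(self, fully_update: bool = False, **kwargs):
        # fully_update=True asks for a complete parameter push (the sync_models_to_remote_makers
        # path); what a partial update sends is left to the concrete trainer and its strategy.
        for holder in self.target_holder_list:
            holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update)
        # ... stream state-dict shards to each holder here ...
        for holder in self.target_holder_list:
            holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)

    def training_step(self, experience: Experience) -> Dict[str, Any]:
        raise NotImplementedError
```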
149 changes: 36 additions & 113 deletions applications/Chat/coati/ray/src/detached_trainer_ppo.py
@@ -1,10 +1,9 @@
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import ray
import torch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic
from coati.models.generation_utils import update_model_kwargs_fn
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.callbacks import Callback
from coati.trainer.callbacks.performance_evaluator import TrainerPerformaceEvaluator
@@ -54,54 +53,38 @@ class DetachedPPOTrainer(DetachedTrainer):
'''

def __init__(
self,
experience_maker_holder_name_list: List[str],
strategy: str,
model: str,
pretrained: str = None,
lora_rank: int = 0,
cr_model: str = None, # if not None, use below cr settings for critic
cr_pretrained: str = None,
cr_lora_rank: int = 0,
env_info: Dict[str, str] = None,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
value_clip: float = 0.4,
experience_batch_size: int = 8,
max_epochs: int = 10,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
eval_performance: bool = False,
debug: bool = False,
**generate_kwargs) -> None:
self,
experience_maker_holder_name_list: List[str],
strategy_fn: Callable[[], Strategy],
model_fn: Callable[[], Tuple[Actor, Critic]],
env_info: Dict[str, str] = None,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
value_clip: float = 0.4,
max_epochs: int = 10,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
eval_performance: bool = False,
debug: bool = False,
) -> None:
# set environment variables
if env_info:
set_dist_env(env_info=env_info)
# configure strategy
self.strategy = get_strategy_from_args(strategy)
self.strategy = strategy_fn()
# configure models, loss and optimizers
if cr_model is None:
cr_model = model
cr_pretrained = pretrained
cr_lora_rank = lora_rank

with self.strategy.model_init_context():
self.actor = get_actor_from_args(model, pretrained, lora_rank)
self.critic = get_critic_from_args(cr_model, cr_pretrained, cr_lora_rank)
self.actor, self.critic = model_fn()

if eval_performance:
actor_numel = get_model_numel(self.actor)
critic_numel = get_model_numel(self.critic)
evaluator = TrainerPerformaceEvaluator(actor_numel, critic_numel)
callbacks = callbacks + [evaluator]

if strategy != 'colossalai_gemini':
self.actor.to(torch.cuda.current_device()) # .to(torch.float16)
self.critic.to(torch.cuda.current_device()) # .to(torch.float16)

if strategy.startswith('colossalai'):
if isinstance(self.strategy, ColossalAIStrategy):
self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
else:
@@ -112,96 +95,49 @@ def __init__(
self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))

# configure trainer
generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)

super().__init__(experience_maker_holder_name_list,
train_batch_size=train_batch_size,
buffer_limit=buffer_limit,
buffer_cpu_offload=buffer_cpu_offload,
experience_batch_size=experience_batch_size,
max_epochs=max_epochs,
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
debug=debug,
**generate_kwargs)

# for remote maker initialization
self._model_str = model
self._cr_model_str = cr_model
self._pretrained = pretrained
self._cr_pretrained = cr_pretrained
debug=debug)

@ray.method(concurrency_group="model_io")
@torch.no_grad()
def _update_remote_makers(self, **config):
def _update_remote_makers(self, fully_update: bool = False, **config):
# TODO: balance duties
if is_rank_0():
self.update_target_holder_list(self.target_holder_name_list)
# actor:
if is_rank_0():
# mark start
# mark start, ensure order
tasks = []
for target_holder in self.target_holder_list:
target_holder.update_experience_maker.remote(chunk_start=True)
tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
ray.get(tasks)
# sending loop
tasks = []
for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_model(self.actor), **config):
if is_rank_0():
for target_holder in self.target_holder_list:
target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard)
if is_rank_0():
# mark end
for target_holder in self.target_holder_list:
target_holder.update_experience_maker.remote(chunk_end=True)
# critic
if is_rank_0():
# mark start
for target_holder in self.target_holder_list:
target_holder.update_experience_maker.remote(chunk_start=True)
# sending loop
tasks.append(
target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
fully_update=fully_update))
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
if is_rank_0():
for target_holder in self.target_holder_list:
target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard)
tasks.append(
target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
fully_update=fully_update))
ray.get(tasks)
if is_rank_0():
# mark end
for target_holder in self.target_holder_list:
target_holder.update_experience_maker.remote(chunk_end=True)

@ray.method(concurrency_group="model_io")
def initialize_remote_makers(self, **config):
# TODO: balance duties
if is_rank_0():
self.update_target_holder_list(self.target_holder_name_list)
with torch.no_grad():
# actor / initial_model:
# mark start
for target_holder in self.target_holder_list:
target_holder.initialize_experience_maker.remote(actor_model=self._model_str,
actor_pretrained=self._pretrained,
chunk_start=True)
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_actor(self.actor),
**config):
for target_holder in self.target_holder_list:
target_holder.initialize_experience_maker.remote(actor_state_dict=state_dict_shard)
# mark end
for target_holder in self.target_holder_list:
target_holder.initialize_experience_maker.remote(actor_model=self._model_str, chunk_end=True)
# critic / reward_model:
# mark start
for target_holder in self.target_holder_list:
target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str,
critic_pretrained=self._cr_pretrained,
chunk_start=True)
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic),
**config):
for target_holder in self.target_holder_list:
target_holder.initialize_experience_maker.remote(critic_state_dict=state_dict_shard)
# mark end
for target_holder in self.target_holder_list:
target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str, chunk_end=True)
target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)

@ray.method(concurrency_group="compute")
def training_step(self, experience: Experience) -> Dict[str, float]:
@@ -273,16 +209,3 @@ def _get_model_state_dict_shard(self, model: torch.nn.Module, **config):
pass
for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
yield state_dict_to(state_dict)


def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
origin_model = strategy._unwrap_actor(actor)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation

if 'update_model_kwargs_fn' not in generate_kwargs:
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn

return new_kwargs
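`initialize_remote_makers` and the local `_set_default_generate_kwargs` helper are removed; a full weight sync now reuses the chunked `_update_remote_makers` path, with `chunk_start`/`chunk_end` framing each pass and `fully_update=True` marking a complete push. A sketch of how the driver could trigger it, mirroring the block left commented out in `1mmt_dummy.py`; it assumes `trainer_refs` from that example and that the maker was created with `sync_models_from_trainers=True`:

```python
import ray

# Each call runs _update_remote_makers(fully_update=True) on the trainer, which streams
# actor and critic state-dict shards to every registered ExperienceMakerHolder and frames
# them with chunk_start/chunk_end so the holder can tell when a transfer is complete.
ray.get([
    trainer_ref.sync_models_to_remote_makers.remote()
    for trainer_ref in trainer_refs
])
```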