Commit d2428f8

[chat] refactor trainer and maker (#12)
* [chat] refactor experience maker holder
* [chat] refactor model init
* [chat] refactor trainer args
* [chat] refactor model init
* [chat] refactor trainer
1 parent bf11014 commit d2428f8
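
In short, the trainer and maker no longer receive model/strategy names as strings and build everything internally; the caller now passes zero-argument factory callables (`model_fn`, `strategy_fn`) that construct the models and strategy on demand. A minimal, self-contained sketch of that pattern; the `ToyTrainer`, `make_actor`, and `make_critic` names below are illustrative stand-ins, not coati APIs:

```python
# Sketch of the factory-callable pattern this commit moves to.
# ToyTrainer, make_actor and make_critic are illustrative, not part of coati.
from typing import Callable, Tuple

import torch.nn as nn


def make_actor() -> nn.Module:
    return nn.Linear(8, 8)


def make_critic() -> nn.Module:
    return nn.Linear(8, 1)


class ToyTrainer:
    def __init__(self, model_fn: Callable[[], Tuple[nn.Module, nn.Module]]):
        # The trainer no longer knows about model names, pretrained paths or
        # LoRA ranks; it just invokes the factory inside its own init context.
        self.actor, self.critic = model_fn()


def model_fn() -> Tuple[nn.Module, nn.Module]:
    return make_actor(), make_critic()


trainer = ToyTrainer(model_fn=model_fn)
print(type(trainer.actor).__name__, type(trainer.critic).__name__)  # Linear Linear
```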

File tree: 5 files changed (+117 additions, -294 deletions)
applications/Chat/coati/ray/example/1mmt_dummy.py

Lines changed: 31 additions & 33 deletions
@@ -1,15 +1,18 @@
 import argparse
 import os
 import socket
-from copy import deepcopy
 from functools import partial
 
 import ray
 import torch
-from coati.models.base import RewardModel
 from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
 from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
-from coati.ray.src.utils import get_actor_from_args, get_critic_from_args, get_reward_model_from_args
+from coati.ray.src.utils import (
+    get_actor_from_args,
+    get_critic_from_args,
+    get_reward_model_from_args,
+    get_strategy_from_args,
+)
 from transformers import AutoTokenizer, BloomTokenizerFast
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
@@ -81,39 +84,44 @@ def main(args):
     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
     tokenizer.pad_token = tokenizer.eos_token
 
+    def trainer_model_fn():
+        actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
+        critic = get_critic_from_args(args.model, args.pretrain).half().cuda()
+        return actor, critic
+
     # configure Trainer
     trainer_refs = [
         DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
             experience_maker_holder_name_list=["maker1"],
-            strategy=args.trainer_strategy,
-            model=args.model,
+            strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
+            model_fn=trainer_model_fn,
             env_info=env_info_trainer,
-            pretrained=args.pretrain,
-            lora_rank=args.lora_rank,
             train_batch_size=args.train_batch_size,
             buffer_limit=16,
-            experience_batch_size=args.experience_batch_size,
            max_epochs=args.max_epochs,
-            # kwargs:
-            max_length=512,
-            do_sample=True,
-            temperature=1.0,
-            top_k=50,
-            pad_token_id=tokenizer.pad_token_id,
-            eos_token_id=tokenizer.eos_token_id,
             eval_performance=True,
             debug=args.debug,
         ) for i, env_info_trainer in enumerate(env_info_trainers)
     ]
 
+    def model_fn():
+        actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
+        critic = get_critic_from_args(args.model, args.pretrain).half().cuda()
+        reward_model = get_reward_model_from_args(args.model, args.pretrain).half().cuda()
+        initial_model = get_actor_from_args(args.model, args.pretrain).half().cuda()
+        return actor, critic, reward_model, initial_model
+
     # configure Experience Maker
     experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
         detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
-        strategy=args.maker_strategy,
+        strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
+        model_fn=model_fn,
         env_info=env_info_maker,
         experience_batch_size=args.experience_batch_size,
         kl_coef=0.1,
-        # kwargs:
+        debug=args.debug,
+        # sync_models_from_trainers=True,
+        # generation kwargs:
         max_length=512,
         do_sample=True,
         temperature=1.0,
@@ -122,32 +130,22 @@ def main(args):
         eos_token_id=tokenizer.eos_token_id,
         eval_performance=True,
         use_cache=True,
-        debug=args.debug,
     )
 
-    def init_inference_model(fn, model_name, pretrained):
-        model = fn(model_name, pretrained)
-        return model.half().cuda()
-
-    # init maker locally
-    ray.get(
-        experience_holder_ref.initialize_experience_maker_local.remote(
-            initial_model_func=partial(init_inference_model, get_actor_from_args, args.model, args.pretrain),
-            reward_model_func=partial(init_inference_model, get_reward_model_from_args, args.model, args.pretrain),
-            actor_func=partial(init_inference_model, get_actor_from_args, args.model, args.pretrain),
-            critic_func=partial(init_inference_model, get_critic_from_args, args.model, args.pretrain),
-        ))
-
     # configure sampler
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400))
 
     def tokenize_fn(texts):
-        # print(texts)
         input_ids = torch.stack(texts).cuda()
-        # print(input_ids.shape)
         attn_mask = torch.ones_like(input_ids)
         return {'input_ids': input_ids, 'attention_mask': attn_mask}
 
+    # uncomment this function if sync_models_from_trainers is True
+    # ray.get([
+    #     trainer_ref.sync_models_to_remote_makers.remote()
+    #     for trainer_ref in trainer_refs
+    # ])
+
     wait_tasks = []
 
     for trainer_ref in trainer_refs:
applications/Chat/coati/ray/src/detached_trainer_base.py

Lines changed: 5 additions & 7 deletions
@@ -21,7 +21,6 @@ class DetachedTrainer(ABC):
     Args:
         detached_strategy (DetachedStrategy): the strategy to use for training
         detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
-        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
         max_epochs (int, defaults to 1): the number of epochs of training process
         data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
@@ -34,21 +33,17 @@ def __init__(self,
                  train_batch_size: int = 8,
                  buffer_limit: int = 0,
                  buffer_cpu_offload: bool = True,
-                 experience_batch_size: int = 8,
                  max_epochs: int = 1,
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
-                 debug: bool = False,
-                 **generate_kwargs) -> None:
+                 debug: bool = False) -> None:
         super().__init__()
         self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size,
                                                            limit=buffer_limit,
                                                            cpu_offload=buffer_cpu_offload)
-        self.experience_batch_size = experience_batch_size
         self.max_epochs = max_epochs
         self.dataloader_pin_memory = dataloader_pin_memory
         self.callbacks = callbacks
-        self.generate_kwargs = generate_kwargs
         self.target_holder_name_list = experience_maker_holder_name_list
         self.target_holder_list = []
 
@@ -61,9 +56,12 @@ def update_target_holder_list(self, experience_maker_holder_name_list):
             self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
 
     @abstractmethod
-    def _update_remote_makers(self):
+    def _update_remote_makers(self, fully_update: bool = False, **kwargs):
         pass
 
+    def sync_models_to_remote_makers(self, **kwargs):
+        self._update_remote_makers(fully_update=True, **kwargs)
+
     @abstractmethod
     def training_step(self, experience: Experience) -> Dict[str, Any]:
         pass
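
The net effect in the base class is a small template method: `sync_models_to_remote_makers` is the public entry point and simply calls the abstract `_update_remote_makers` hook with `fully_update=True`. A stand-alone sketch of that shape; the `ToyDetachedTrainer`/`PrintingTrainer` classes are illustrative, not the coati implementation:

```python
# Template-method sketch: a concrete wrapper forwards to an abstract hook
# with fully_update=True. Class names here are hypothetical.
from abc import ABC, abstractmethod


class ToyDetachedTrainer(ABC):
    @abstractmethod
    def _update_remote_makers(self, fully_update: bool = False, **kwargs):
        ...

    def sync_models_to_remote_makers(self, **kwargs):
        # the flag is simply forwarded; what a "full" update means is up to the subclass
        self._update_remote_makers(fully_update=True, **kwargs)


class PrintingTrainer(ToyDetachedTrainer):
    def _update_remote_makers(self, fully_update: bool = False, **kwargs):
        print(f"pushing weights (fully_update={fully_update}) with {kwargs}")


PrintingTrainer().sync_models_to_remote_makers(chunk_size=4)
```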
applications/Chat/coati/ray/src/detached_trainer_ppo.py

Lines changed: 36 additions & 113 deletions
@@ -1,10 +1,9 @@
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import ray
 import torch
 from coati.experience_maker import Experience, NaiveExperienceMaker
 from coati.models.base import Actor, Critic
-from coati.models.generation_utils import update_model_kwargs_fn
 from coati.models.loss import PolicyLoss, ValueLoss
 from coati.trainer.callbacks import Callback
 from coati.trainer.callbacks.performance_evaluator import TrainerPerformaceEvaluator
@@ -54,54 +53,38 @@ class DetachedPPOTrainer(DetachedTrainer):
     '''
 
     def __init__(
-            self,
-            experience_maker_holder_name_list: List[str],
-            strategy: str,
-            model: str,
-            pretrained: str = None,
-            lora_rank: int = 0,
-            cr_model: str = None,    # if not None, use below cr settings for critic
-            cr_pretrained: str = None,
-            cr_lora_rank: int = 0,
-            env_info: Dict[str, str] = None,
-            train_batch_size: int = 8,
-            buffer_limit: int = 0,
-            buffer_cpu_offload: bool = True,
-            eps_clip: float = 0.2,
-            value_clip: float = 0.4,
-            experience_batch_size: int = 8,
-            max_epochs: int = 10,
-            dataloader_pin_memory: bool = True,
-            callbacks: List[Callback] = [],
-            eval_performance: bool = False,
-            debug: bool = False,
-            **generate_kwargs) -> None:
+        self,
+        experience_maker_holder_name_list: List[str],
+        strategy_fn: Callable[[], Strategy],
+        model_fn: Callable[[], Tuple[Actor, Critic]],
+        env_info: Dict[str, str] = None,
+        train_batch_size: int = 8,
+        buffer_limit: int = 0,
+        buffer_cpu_offload: bool = True,
+        eps_clip: float = 0.2,
+        value_clip: float = 0.4,
+        max_epochs: int = 10,
+        dataloader_pin_memory: bool = True,
+        callbacks: List[Callback] = [],
+        eval_performance: bool = False,
+        debug: bool = False,
+    ) -> None:
         # set environment variables
         if env_info:
             set_dist_env(env_info=env_info)
         # configure strategy
-        self.strategy = get_strategy_from_args(strategy)
+        self.strategy = strategy_fn()
         # configure models, loss and optimizers
-        if cr_model is None:
-            cr_model = model
-            cr_pretrained = pretrained
-            cr_lora_rank = lora_rank
-
         with self.strategy.model_init_context():
-            self.actor = get_actor_from_args(model, pretrained, lora_rank)
-            self.critic = get_critic_from_args(cr_model, cr_pretrained, cr_lora_rank)
+            self.actor, self.critic = model_fn()
 
         if eval_performance:
             actor_numel = get_model_numel(self.actor)
             critic_numel = get_model_numel(self.critic)
             evaluator = TrainerPerformaceEvaluator(actor_numel, critic_numel)
             callbacks = callbacks + [evaluator]
 
-        if strategy != 'colossalai_gemini':
-            self.actor.to(torch.cuda.current_device())    # .to(torch.float16)
-            self.critic.to(torch.cuda.current_device())    # .to(torch.float16)
-
-        if strategy.startswith('colossalai'):
+        if isinstance(self.strategy, ColossalAIStrategy):
             self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
             self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
         else:
@@ -112,96 +95,49 @@ def __init__(
         self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
 
         # configure trainer
-        generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)
 
         super().__init__(experience_maker_holder_name_list,
                          train_batch_size=train_batch_size,
                          buffer_limit=buffer_limit,
                          buffer_cpu_offload=buffer_cpu_offload,
-                         experience_batch_size=experience_batch_size,
                          max_epochs=max_epochs,
                          dataloader_pin_memory=dataloader_pin_memory,
                          callbacks=callbacks,
-                         debug=debug,
-                         **generate_kwargs)
-
-        # for remote maker initialization
-        self._model_str = model
-        self._cr_model_str = cr_model
-        self._pretrained = pretrained
-        self._cr_pretrained = cr_pretrained
+                         debug=debug)
 
     @ray.method(concurrency_group="model_io")
     @torch.no_grad()
-    def _update_remote_makers(self, **config):
+    def _update_remote_makers(self, fully_update: bool = False, **config):
         # TODO: balance duties
         if is_rank_0():
             self.update_target_holder_list(self.target_holder_name_list)
-        # actor:
-        if is_rank_0():
-            # mark start
+            # mark start, ensure order
+            tasks = []
             for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_start=True)
+                tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
+            ray.get(tasks)
         # sending loop
+        tasks = []
        for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_model(self.actor), **config):
             if is_rank_0():
                 for target_holder in self.target_holder_list:
-                    target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard)
-        if is_rank_0():
-            # mark end
-            for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_end=True)
-        # critic
-        if is_rank_0():
-            # mark start
-            for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_start=True)
-        # sending loop
+                    tasks.append(
+                        target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
                                                                      fully_update=fully_update))
+        # sending loop
         for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
             if is_rank_0():
                 for target_holder in self.target_holder_list:
-                    target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard)
+                    tasks.append(
+                        target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
                                                                      fully_update=fully_update))
+        ray.get(tasks)
         if is_rank_0():
             # mark end
             for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_end=True)
-
-    @ray.method(concurrency_group="model_io")
-    def initialize_remote_makers(self, **config):
-        # TODO: balance duties
-        if is_rank_0():
-            self.update_target_holder_list(self.target_holder_name_list)
-        with torch.no_grad():
-            # actor / initial_model:
-            # mark start
-            for target_holder in self.target_holder_list:
-                target_holder.initialize_experience_maker.remote(actor_model=self._model_str,
-                                                                 actor_pretrained=self._pretrained,
-                                                                 chunk_start=True)
-            # sending loop
-            for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_actor(self.actor),
-                                                                     **config):
-                for target_holder in self.target_holder_list:
-                    target_holder.initialize_experience_maker.remote(actor_state_dict=state_dict_shard)
-            # mark end
-            for target_holder in self.target_holder_list:
-                target_holder.initialize_experience_maker.remote(actor_model=self._model_str, chunk_end=True)
-            # critic / reward_model:
-            # mark start
-            for target_holder in self.target_holder_list:
-                target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str,
-                                                                 critic_pretrained=self._cr_pretrained,
-                                                                 chunk_start=True)
-            # sending loop
-            for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic),
-                                                                     **config):
-                for target_holder in self.target_holder_list:
-                    target_holder.initialize_experience_maker.remote(critic_state_dict=state_dict_shard)
-            # mark end
-            for target_holder in self.target_holder_list:
-                target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str, chunk_end=True)
+                target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)
 
     @ray.method(concurrency_group="compute")
     def training_step(self, experience: Experience) -> Dict[str, float]:
@@ -273,16 +209,3 @@ def _get_model_state_dict_shard(self, model: torch.nn.Module, **config):
             pass
         for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
             yield state_dict_to(state_dict)
-
-
-def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
-    origin_model = strategy._unwrap_actor(actor)
-    new_kwargs = {**generate_kwargs}
-    # use huggingface models method directly
-    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
-        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
-
-    if 'update_model_kwargs_fn' not in generate_kwargs:
-        new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
-
-    return new_kwargs
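
`_update_remote_makers` streams weights as a sequence of state-dict shards bracketed by `chunk_start`/`chunk_end` markers, forwarding the new `fully_update` flag on every call. A rough, self-contained sketch of that chunked transfer from the receiving side, with a hypothetical `Receiver` standing in for `ExperienceMakerHolder` (whose real update logic is not part of this diff):

```python
# Illustrative chunk protocol: start marker, a stream of shards, end marker.
# Receiver and shard_state_dict are hypothetical stand-ins, not coati code.
from typing import Dict, Iterator

import torch


def shard_state_dict(state_dict: Dict[str, torch.Tensor], shard_size: int) -> Iterator[Dict[str, torch.Tensor]]:
    # yield the state dict a few tensors at a time, loosely mimicking
    # strategy.get_model_state_dict_shard
    shard = {}
    for name, tensor in state_dict.items():
        shard[name] = tensor
        if len(shard) == shard_size:
            yield shard
            shard = {}
    if shard:
        yield shard


class Receiver:
    def __init__(self):
        self.pending = {}
        self.weights = {}

    def update(self, shard=None, chunk_start=False, chunk_end=False, fully_update=False):
        # fully_update is just forwarded here; how the real holder uses it is out of scope
        if chunk_start:
            self.pending = {}
        if shard:
            self.pending.update(shard)
        if chunk_end:
            # apply the accumulated shards once the sender marks the end
            self.weights.update(self.pending)


model = torch.nn.Linear(4, 4)
receiver = Receiver()
receiver.update(chunk_start=True, fully_update=True)
for shard in shard_state_dict(model.state_dict(), shard_size=1):
    receiver.update(shard=shard, fully_update=True)
receiver.update(chunk_end=True, fully_update=True)
print(sorted(receiver.weights))  # ['bias', 'weight']
```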
