@@ -1,10 +1,9 @@
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import ray
 import torch
 from coati.experience_maker import Experience, NaiveExperienceMaker
 from coati.models.base import Actor, Critic
-from coati.models.generation_utils import update_model_kwargs_fn
 from coati.models.loss import PolicyLoss, ValueLoss
 from coati.trainer.callbacks import Callback
 from coati.trainer.callbacks.performance_evaluator import TrainerPerformaceEvaluator
@@ -54,54 +53,38 @@ class DetachedPPOTrainer(DetachedTrainer):
     '''
 
     def __init__(
-            self,
-            experience_maker_holder_name_list: List[str],
-            strategy: str,
-            model: str,
-            pretrained: str = None,
-            lora_rank: int = 0,
-            cr_model: str = None,    # if not None, use below cr settings for critic
-            cr_pretrained: str = None,
-            cr_lora_rank: int = 0,
-            env_info: Dict[str, str] = None,
-            train_batch_size: int = 8,
-            buffer_limit: int = 0,
-            buffer_cpu_offload: bool = True,
-            eps_clip: float = 0.2,
-            value_clip: float = 0.4,
-            experience_batch_size: int = 8,
-            max_epochs: int = 10,
-            dataloader_pin_memory: bool = True,
-            callbacks: List[Callback] = [],
-            eval_performance: bool = False,
-            debug: bool = False,
-            **generate_kwargs) -> None:
+        self,
+        experience_maker_holder_name_list: List[str],
+        strategy_fn: Callable[[], Strategy],
+        model_fn: Callable[[], Tuple[Actor, Critic]],
+        env_info: Dict[str, str] = None,
+        train_batch_size: int = 8,
+        buffer_limit: int = 0,
+        buffer_cpu_offload: bool = True,
+        eps_clip: float = 0.2,
+        value_clip: float = 0.4,
+        max_epochs: int = 10,
+        dataloader_pin_memory: bool = True,
+        callbacks: List[Callback] = [],
+        eval_performance: bool = False,
+        debug: bool = False,
+    ) -> None:
         # set environment variables
         if env_info:
             set_dist_env(env_info=env_info)
         # configure strategy
-        self.strategy = get_strategy_from_args(strategy)
+        self.strategy = strategy_fn()
         # configure models, loss and optimizers
-        if cr_model is None:
-            cr_model = model
-            cr_pretrained = pretrained
-            cr_lora_rank = lora_rank
-
         with self.strategy.model_init_context():
-            self.actor = get_actor_from_args(model, pretrained, lora_rank)
-            self.critic = get_critic_from_args(cr_model, cr_pretrained, cr_lora_rank)
+            self.actor, self.critic = model_fn()
 
         if eval_performance:
             actor_numel = get_model_numel(self.actor)
             critic_numel = get_model_numel(self.critic)
             evaluator = TrainerPerformaceEvaluator(actor_numel, critic_numel)
             callbacks = callbacks + [evaluator]
 
-        if strategy != 'colossalai_gemini':
-            self.actor.to(torch.cuda.current_device())    # .to(torch.float16)
-            self.critic.to(torch.cuda.current_device())    # .to(torch.float16)
-
-        if strategy.startswith('colossalai'):
+        if isinstance(self.strategy, ColossalAIStrategy):
             self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
             self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
         else:
@@ -112,96 +95,49 @@ def __init__(
         self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
 
         # configure trainer
-        generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)
 
         super().__init__(experience_maker_holder_name_list,
                          train_batch_size=train_batch_size,
                          buffer_limit=buffer_limit,
                          buffer_cpu_offload=buffer_cpu_offload,
-                         experience_batch_size=experience_batch_size,
                          max_epochs=max_epochs,
                          dataloader_pin_memory=dataloader_pin_memory,
                          callbacks=callbacks,
-                         debug=debug,
-                         **generate_kwargs)
-
-        # for remote maker initialization
-        self._model_str = model
-        self._cr_model_str = cr_model
-        self._pretrained = pretrained
-        self._cr_pretrained = cr_pretrained
+                         debug=debug)
 
     @ray.method(concurrency_group="model_io")
     @torch.no_grad()
-    def _update_remote_makers(self, **config):
+    def _update_remote_makers(self, fully_update: bool = False, **config):
         # TODO: balance duties
         if is_rank_0():
             self.update_target_holder_list(self.target_holder_name_list)
-        # actor:
-        if is_rank_0():
-            # mark start
+            # mark start, ensure order
+            tasks = []
             for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_start=True)
+                tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
+            ray.get(tasks)
         # sending loop
+        tasks = []
         for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_model(self.actor), **config):
             if is_rank_0():
                 for target_holder in self.target_holder_list:
-                    target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard)
-        if is_rank_0():
-            # mark end
-            for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_end=True)
-        # critic
-        if is_rank_0():
-            # mark start
-            for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_start=True)
-        # sending loop
+                    tasks.append(
+                        target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
+                                                                     fully_update=fully_update))
+        # sending loop
         for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
             if is_rank_0():
                 for target_holder in self.target_holder_list:
-                    target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard)
+                    tasks.append(
+                        target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
+                                                                     fully_update=fully_update))
+        ray.get(tasks)
         if is_rank_0():
             # mark end
             for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_end=True)
-
-    @ray.method(concurrency_group="model_io")
-    def initialize_remote_makers(self, **config):
-        # TODO: balance duties
-        if is_rank_0():
-            self.update_target_holder_list(self.target_holder_name_list)
-        with torch.no_grad():
-            # actor / initial_model:
-            # mark start
-            for target_holder in self.target_holder_list:
-                target_holder.initialize_experience_maker.remote(actor_model=self._model_str,
-                                                                 actor_pretrained=self._pretrained,
-                                                                 chunk_start=True)
-            # sending loop
-            for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_actor(self.actor),
-                                                                     **config):
-                for target_holder in self.target_holder_list:
-                    target_holder.initialize_experience_maker.remote(actor_state_dict=state_dict_shard)
-            # mark end
-            for target_holder in self.target_holder_list:
-                target_holder.initialize_experience_maker.remote(actor_model=self._model_str, chunk_end=True)
-            # critic / reward_model:
-            # mark start
-            for target_holder in self.target_holder_list:
-                target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str,
-                                                                 critic_pretrained=self._cr_pretrained,
-                                                                 chunk_start=True)
-            # sending loop
-            for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic),
-                                                                     **config):
-                for target_holder in self.target_holder_list:
-                    target_holder.initialize_experience_maker.remote(critic_state_dict=state_dict_shard)
-            # mark end
-            for target_holder in self.target_holder_list:
-                target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str, chunk_end=True)
+                target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)
 
     @ray.method(concurrency_group="compute")
     def training_step(self, experience: Experience) -> Dict[str, float]:
@@ -273,16 +209,3 @@ def _get_model_state_dict_shard(self, model: torch.nn.Module, **config):
             pass
         for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
             yield state_dict_to(state_dict)
-
-
-def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
-    origin_model = strategy._unwrap_actor(actor)
-    new_kwargs = {**generate_kwargs}
-    # use huggingface models method directly
-    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
-        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
-
-    if 'update_model_kwargs_fn' not in generate_kwargs:
-        new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
-
-    return new_kwargs
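
The constructor refactor above replaces the string-based strategy/model/pretrained arguments with factory callables. A minimal sketch of how a caller might wire the new signature, assuming the helpers the removed code relied on (get_strategy_from_args, get_actor_from_args, get_critic_from_args) are still importable and that the trainer is launched as a named Ray actor; import paths, model names, and actor options are illustrative, not taken from this commit:

# Sketch only: wiring for the new factory-callable API.
# Helper names come from the removed code; their import paths are assumed.
# from coati.ray import DetachedPPOTrainer, get_strategy_from_args, get_actor_from_args, get_critic_from_args

def strategy_fn():
    # built inside the trainer process via strategy_fn()
    return get_strategy_from_args('colossalai_zero2')

def model_fn():
    # built inside strategy.model_init_context(); positional args mirror the old call sites
    actor = get_actor_from_args('gpt2', None, 0)      # model, pretrained, lora_rank
    critic = get_critic_from_args('gpt2', None, 0)
    return actor, critic

trainer = DetachedPPOTrainer.options(name='trainer1', num_gpus=1).remote(
    experience_maker_holder_name_list=['maker1'],
    strategy_fn=strategy_fn,
    model_fn=model_fn,
    train_batch_size=8,
    eval_performance=True,
)

Deferring construction to callables lets each Ray worker build its own strategy and models locally, instead of shipping model/pretrained strings plus the separate cr_* critic settings through the constructor.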
0 commit comments
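
_update_remote_makers now streams actor and critic state-dict shards to every holder between explicit chunk_start/chunk_end marks, using ray.get as a barrier so the start marks land before any shard ("mark start, ensure order"). The receiving side is not part of this commit; the toy holder below is hypothetical and only mirrors the keyword arguments used here, to illustrate the protocol:

import ray

@ray.remote
class MockMakerHolder:
    """Hypothetical stand-in for the experience maker holder; protocol illustration only."""

    def __init__(self):
        self.receiving = False
        self.actor_shards = []
        self.critic_shards = []

    def update_experience_maker(self, new_actor_state_dict=None, new_critic_state_dict=None,
                                fully_update=False, chunk_start=False, chunk_end=False):
        if chunk_start:
            # the trainer ray.get()s the chunk_start calls first, so shards never arrive early
            self.receiving = True
        if new_actor_state_dict is not None:
            assert self.receiving
            self.actor_shards.append(new_actor_state_dict)
        if new_critic_state_dict is not None:
            assert self.receiving
            self.critic_shards.append(new_critic_state_dict)
        if chunk_end:
            # all shards received; a real holder would load them into its models,
            # replacing them entirely when fully_update is True
            self.receiving = False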