@@ -377,6 +377,7 @@ def __init__(
         self.max_prompt_length = args.max_prompt_length
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
         self.num_generations = args.num_generations  # = G in the GRPO paper
+        self.chat_template_kwargs = args.chat_template_kwargs or {}
         self.temperature = args.temperature
         self.top_p = args.top_p
         self.top_k = args.top_k
@@ -1066,7 +1067,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
             if isinstance(reward_func, nn.Module):  # Module (no PretrainedModel) for compat with compiled models
                 if is_conversational(inputs[0]):
                     messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
-                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                    texts = [
+                        apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"]
+                        for x in messages
+                    ]
                 else:
                     texts = [p + c for p, c in zip(prompts, completions)]
                 reward_inputs = reward_processing_class(
@@ -1146,7 +1150,9 @@ def _generate_single_turn(self, prompts: list):
         if self.rollout_func is not None:
             if is_conversational({"prompt": ordered_set_of_prompts[0]}):
                 ordered_set_of_prompts = [
-                    apply_chat_template({"prompt": p}, self.processing_class)["prompt"]
+                    apply_chat_template(
+                        {"prompt": p}, self.processing_class, **self.chat_template_kwargs
+                    )["prompt"]
                     for p in ordered_set_of_prompts
                 ]
             output = self.rollout_func(
@@ -1157,7 +1163,11 @@ def _generate_single_turn(self, prompts: list):
         else:
             if is_conversational({"prompt": ordered_set_of_prompts[0]}):
                 # FIXME: this endpoint doesn't exist in vllm_client
-                output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params)
+                output = self.vllm_client.chat(
+                    prompts=ordered_set_of_prompts,
+                    **sampling_params,
+                    chat_template_kwargs=self.chat_template_kwargs,
+                )
             else:
                 output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
         # Extract required fields and collect any extra fields for reward functions
@@ -1272,6 +1282,7 @@ def _generate_single_turn(self, prompts: list):
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
+                **self.chat_template_kwargs,
             )
         else:
             processor_outputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1317,6 +1328,7 @@ def _generate_single_turn(self, prompts: list):
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
+                **self.chat_template_kwargs,
             )
         else:
             generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1450,7 +1462,8 @@ def _generate_and_score_completions(
         # Get forward_kwargs for models with multimodal inputs
         if images is not None:
             prompts_text = [
-                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+                apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
+                for prompt in prompts
             ]
             prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
             prompt_inputs = super()._prepare_inputs(prompt_inputs)
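
Below the diff is a minimal usage sketch of what the change enables, assuming the companion GRPOConfig change exposes a `chat_template_kwargs` field and that the target model's chat template accepts an `enable_thinking` flag (Qwen3-style templates do); the dataset and reward function are toy placeholders, not part of this PR.

# Minimal sketch under the assumptions above: set chat-template kwargs once in the config
# and the trainer forwards them to every apply_chat_template call touched by this diff.
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# Tiny conversational dataset: chat_template_kwargs only matter when prompts are chat messages.
train_dataset = Dataset.from_dict(
    {"prompt": [[{"role": "user", "content": "Summarize: the sky is blue."}]] * 8}
)

def reward_len(completions, **kwargs):
    # Toy reward: prefer assistant replies close to 50 characters.
    return [-abs(50 - len(c[0]["content"])) for c in completions]

training_args = GRPOConfig(
    output_dir="qwen3-grpo-demo",
    chat_template_kwargs={"enable_thinking": False},  # hypothetical field added by this PR
)
trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()

The `or {}` fallback in `__init__` keeps the attribute safe to splat into `apply_chat_template` when the config field is left unset.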