
Commit 9925469

Support chat_template_kwargs (#4350)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
1 parent 4e9ab9f commit 9925469

File tree

6 files changed: 120 additions, 17 deletions

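Overview of the new option as a minimal usage sketch. It follows the training API exercised by the tests in this diff; the model and dataset names are the commit's own test fixtures, and the toy reward function is illustrative only.

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")

# A trivial reward function, only to make the sketch self-contained.
def reward_num_messages(completions, **kwargs):
    return [float(len(c)) for c in completions]

training_args = GRPOConfig(
    output_dir="grpo-chat-template-kwargs",
    per_device_train_batch_size=3,
    num_generations=3,
    max_completion_length=8,
    # Forwarded verbatim to `apply_chat_template` whenever prompts are templated.
    chat_template_kwargs={"enable_thinking": False},
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen3ForCausalLM",
    reward_funcs=reward_num_messages,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()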
tests/test_grpo_trainer.py

Lines changed: 30 additions & 0 deletions

@@ -1606,6 +1606,36 @@ def test_training_sequence_importance_sampling(self):
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
+    def test_training_with_chat_template_kwargs(self):
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,
+            per_device_train_batch_size=3,
+            num_generations=3,
+            max_completion_length=8,
+            report_to="none",
+            chat_template_kwargs={"enable_thinking": False},
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen3ForCausalLM",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
     def test_mismatched_reward_processing_classes_length(self):
         """Test that mismatched length between reward_funcs and reward_processing_classes raises error."""
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

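What the `enable_thinking` key in the test above ultimately controls, shown at the tokenizer level. This is a sketch assuming a Qwen3-style chat template (for example `Qwen/Qwen3-0.6B`); other chat templates may ignore or reject unknown keys, depending on how they are written.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
messages = [{"role": "user", "content": "What color is the sky?"}]

# Keys placed in `chat_template_kwargs` are passed straight through to
# `apply_chat_template`, so they can toggle template-specific behavior
# such as Qwen3's thinking block.
default_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
no_thinking_prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False, enable_thinking=False
)
print(default_prompt != no_thinking_prompt)  # the rendered prompts differ in the thinking section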
tests/test_rloo_trainer.py

Lines changed: 31 additions & 0 deletions

@@ -1307,6 +1307,37 @@ def reward_func(completions, **kwargs):
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
+    def test_training_with_chat_template_kwargs(self):
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
+
+        training_args = RLOOConfig(
+            bf16=False,
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,
+            per_device_train_batch_size=3,
+            num_generations=3,
+            max_completion_length=8,
+            report_to="none",
+            chat_template_kwargs={"enable_thinking": False},
+        )
+        trainer = RLOOTrainer(
+            model="trl-internal-testing/tiny-Qwen3ForCausalLM",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
     def test_mismatched_reward_processing_classes_length(self):
         """Test that mismatched length between reward_funcs and reward_processing_classes raises error."""
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

trl/trainer/grpo_config.py

Lines changed: 14 additions & 5 deletions

@@ -81,6 +81,13 @@ class GRPOConfig(TrainingArguments):
         min_p (`float`, *optional*):
             Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
             value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
+        generation_kwargs (`dict[str, Any]`, *optional*):
+            Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
+            `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
+            generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
+            with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
+        chat_template_kwargs (`dict[str, Any]`, *optional*):
+            Additional keyword arguments to pass to the `apply_chat_template` function when generating completions.
         repetition_penalty (`float`, *optional*, defaults to `1.0`):
             Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
             Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
@@ -91,11 +98,6 @@ class GRPOConfig(TrainingArguments):
             parameter is only effective when `use_vllm` is set to `False`.
         cache_implementation (`str`, *optional*):
             Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
-        generation_kwargs (`dict[str, Any]`, *optional*):
-            Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
-            `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
-            generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
-            with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
 
     > Parameters that control generation acceleration powered by vLLM
 
@@ -375,6 +377,13 @@ class GRPOConfig(TrainingArguments):
             "conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them."
         },
     )
+    chat_template_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={
+            "help": "Additional keyword arguments to pass to the `apply_chat_template` function when generating "
+            "completions."
+        },
+    )
     repetition_penalty: float = field(
         default=1.0,
         metadata={

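The two dict-valued options documented in the hunks above, side by side. A minimal sketch; the specific keys are examples of what each dict accepts, not library defaults.

from trl import GRPOConfig

args = GRPOConfig(
    output_dir="out",
    # Forwarded to GenerationConfig (transformers) or SamplingParams (vLLM);
    # conflicting keys override top-level options such as `top_p` or `min_p`.
    generation_kwargs={"suppress_tokens": [0]},
    # Forwarded verbatim to `apply_chat_template` when prompts are templated.
    chat_template_kwargs={"enable_thinking": False},
)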
trl/trainer/grpo_trainer.py

Lines changed: 17 additions & 4 deletions

@@ -377,6 +377,7 @@ def __init__(
         self.max_prompt_length = args.max_prompt_length
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
         self.num_generations = args.num_generations  # = G in the GRPO paper
+        self.chat_template_kwargs = args.chat_template_kwargs or {}
         self.temperature = args.temperature
         self.top_p = args.top_p
         self.top_k = args.top_k
@@ -1066,7 +1067,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
             if isinstance(reward_func, nn.Module):  # Module (no PretrainedModel) for compat with compiled models
                 if is_conversational(inputs[0]):
                     messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
-                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                    texts = [
+                        apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"]
+                        for x in messages
+                    ]
                 else:
                     texts = [p + c for p, c in zip(prompts, completions)]
                 reward_inputs = reward_processing_class(
@@ -1146,7 +1150,9 @@ def _generate_single_turn(self, prompts: list):
         if self.rollout_func is not None:
             if is_conversational({"prompt": ordered_set_of_prompts[0]}):
                 ordered_set_of_prompts = [
-                    apply_chat_template({"prompt": p}, self.processing_class)["prompt"]
+                    apply_chat_template(
+                        {"prompt": p}, self.processing_class, **self.chat_template_kwargs
+                    )["prompt"]
                     for p in ordered_set_of_prompts
                 ]
             output = self.rollout_func(
@@ -1157,7 +1163,11 @@ def _generate_single_turn(self, prompts: list):
         else:
             if is_conversational({"prompt": ordered_set_of_prompts[0]}):
                 # FIXME: this endpoint doesn't exist in vllm_client
-                output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params)
+                output = self.vllm_client.chat(
+                    prompts=ordered_set_of_prompts,
+                    **sampling_params,
+                    chat_template_kwargs=self.chat_template_kwargs,
+                )
             else:
                 output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
         # Extract required fields and collect any extra fields for reward functions
@@ -1272,6 +1282,7 @@ def _generate_single_turn(self, prompts: list):
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
+                **self.chat_template_kwargs,
             )
         else:
             processor_outputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1317,6 +1328,7 @@ def _generate_single_turn(self, prompts: list):
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
+                **self.chat_template_kwargs,
             )
         else:
             generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1450,7 +1462,8 @@ def _generate_and_score_completions(
         # Get forward_kwargs for models with multimodal inputs
         if images is not None:
             prompts_text = [
-                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+                apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
+                for prompt in prompts
             ]
             prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
             prompt_inputs = super()._prepare_inputs(prompt_inputs)

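The pattern the hunks above repeat: the stored kwargs default to an empty dict, so leaving the option unset is a no-op, and every templating call simply splats them through. A standalone sketch of that pattern with a hypothetical helper class, assuming `apply_chat_template(example, tokenizer, **kwargs)` accepts extra template kwargs as it does after this commit.

from trl.data_utils import apply_chat_template

class PromptRenderer:
    """Hypothetical helper, for illustration only; not part of the library."""

    def __init__(self, args, processing_class):
        self.processing_class = processing_class
        # None becomes {}, so `**self.chat_template_kwargs` is harmless when unset.
        self.chat_template_kwargs = args.chat_template_kwargs or {}

    def render(self, prompt_messages):
        example = {"prompt": prompt_messages}
        return apply_chat_template(example, self.processing_class, **self.chat_template_kwargs)["prompt"]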
trl/trainer/rloo_config.py

Lines changed: 14 additions & 5 deletions

@@ -81,6 +81,13 @@ class RLOOConfig(TrainingArguments):
         min_p (`float`, *optional*):
             Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
             value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
+        generation_kwargs (`dict[str, Any]`, *optional*):
+            Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
+            `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
+            generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
+            with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
+        chat_template_kwargs (`dict[str, Any]`, *optional*):
+            Additional keyword arguments to pass to the `apply_chat_template` function when generating completions.
         repetition_penalty (`float`, *optional*, defaults to `1.0`):
             Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
             Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
@@ -91,11 +98,6 @@ class RLOOConfig(TrainingArguments):
             parameter is only effective when `use_vllm` is set to `False`.
         cache_implementation (`str`, *optional*):
             Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
-        generation_kwargs (`dict[str, Any]`, *optional*):
-            Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
-            `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
-            generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
-            with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
 
     > Parameters that control generation acceleration powered by vLLM
 
@@ -327,6 +329,13 @@ class RLOOConfig(TrainingArguments):
             "conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them."
         },
     )
+    chat_template_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={
+            "help": "Additional keyword arguments to pass to the `apply_chat_template` function when generating "
+            "completions."
+        },
+    )
     repetition_penalty: float = field(
         default=1.0,
         metadata={

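The declaration pattern shared by both config classes, in isolation: an optional dict field that defaults to None and is later normalized to {} inside the trainer. A minimal sketch, not the full config class.

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ExampleConfig:
    # Mirrors the field added to GRPOConfig and RLOOConfig in this commit.
    chat_template_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "Additional keyword arguments to pass to the `apply_chat_template` function when "
            "generating completions."
        },
    )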
trl/trainer/rloo_trainer.py

Lines changed: 14 additions & 3 deletions

@@ -361,6 +361,7 @@ def __init__(
         self.max_prompt_length = args.max_prompt_length
         self.max_completion_length = args.max_completion_length
         self.num_generations = args.num_generations
+        self.chat_template_kwargs = args.chat_template_kwargs or {}
         self.temperature = args.temperature
         self.top_p = args.top_p
         self.top_k = args.top_k
@@ -927,7 +928,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
             if isinstance(reward_func, nn.Module):  # Module (no PretrainedModel) for compat with compiled models
                 if is_conversational(inputs[0]):
                     messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
-                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                    texts = [
+                        apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"]
+                        for x in messages
+                    ]
                 else:
                     texts = [p + c for p, c in zip(prompts, completions)]
                 reward_inputs = reward_processing_class(
@@ -1004,7 +1008,11 @@ def _generate_single_turn(self, prompts: list):
             }
             with profiling_context(self, "vLLM.generate"):
                 if is_conversational({"prompt": ordered_set_of_prompts[0]}):
-                    output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params)
+                    output = self.vllm_client.chat(
+                        prompts=ordered_set_of_prompts,
+                        **sampling_params,
+                        chat_template_kwargs=self.chat_template_kwargs,
+                    )
                 else:
                     output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
             payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
@@ -1097,6 +1105,7 @@ def _generate_single_turn(self, prompts: list):
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
+                **self.chat_template_kwargs,
             )
         else:
             processor_outputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1140,6 +1149,7 @@ def _generate_single_turn(self, prompts: list):
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
+                **self.chat_template_kwargs,
            )
         else:
             generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1265,7 +1275,8 @@ def _generate_and_score_completions(
         # Get forward_kwargs for models with multimodal inputs
         if images is not None:
             prompts_text = [
-                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+                apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
+                for prompt in prompts
             ]
             prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
             prompt_inputs = super()._prepare_inputs(prompt_inputs)

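The same option on the RLOO side, as a compact config-only sketch; the key shown is the one exercised by this commit's tests and is only an example of what a chat template may accept.

from trl import RLOOConfig

args = RLOOConfig(
    output_dir="out",
    # Same semantics as in GRPOConfig: forwarded to `apply_chat_template`.
    chat_template_kwargs={"enable_thinking": False},
)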