Commit 3715297

albertvillanova authored and kashif committed
Pass required token_type_ids (#4148)
1 parent df80f67 commit 3715297

File tree: 5 files changed, +35 −0 lines changed

trl/experimental/gfpo/gfpo_trainer.py
trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py
trl/experimental/gspo_token/grpo_trainer.py
trl/trainer/grpo_trainer.py
trl/trainer/rloo_trainer.py
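
The change is the same across all five trainers: when the processor supplies token_type_ids for the prompt (some multimodal processors require them in the forward pass, e.g. to distinguish image tokens from text tokens), they are extended with zeros to cover the generated completion tokens, passed along with the other forward kwargs, and sliced per micro-batch like input_ids. Below is a minimal standalone sketch of the extension step; it is not TRL source, and the shapes and the "0 = text token" convention are illustrative assumptions.

# Minimal sketch (not TRL source): extend prompt token_type_ids with zeros so they
# cover the generated completion tokens. Shapes and the "0 = text token" convention
# are illustrative assumptions.
import torch

prompt_ids = torch.randint(0, 1000, (2, 5))      # (B, P) prompt token ids
completion_ids = torch.randint(0, 1000, (2, 3))  # (B, C) generated completion ids
token_type_ids = torch.tensor([[0, 1, 1, 0, 0],  # (B, P), e.g. 1 marks image tokens
                               [0, 0, 1, 1, 0]])

# Completions contain only text tokens, so pad the type ids with zeros along the sequence dim.
token_type_ids = torch.cat(
    [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)  # (B, P+C)

prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
assert token_type_ids.shape == prompt_completion_ids.shape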

trl/experimental/gfpo/gfpo_trainer.py

Lines changed: 8 additions & 0 deletions
@@ -94,6 +94,12 @@ def _generate_and_score_completions(self, inputs):
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
+        # If token_type_ids are used, extend them with zeros for the completion part
+        if "token_type_ids" in forward_kwargs:
+            token_type_ids = forward_kwargs["token_type_ids"]
+            forward_kwargs["token_type_ids"] = torch.cat(
+                [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
+            )
 
         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
         batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

@@ -337,6 +343,8 @@ def _generate_and_score_completions(self, inputs):
             output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
         if "image_sizes" in forward_kwargs:
             output["image_sizes"] = forward_kwargs["image_sizes"]
+        if "token_type_ids" in forward_kwargs:
+            output["token_type_ids"] = forward_kwargs["token_type_ids"]
         if images is not None:
             output["num_images"] = num_images
         return output

trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py

Lines changed: 2 additions & 0 deletions
@@ -300,6 +300,8 @@ def _generate_and_score_completions(
             output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
         if "image_sizes" in forward_kwargs:
             output["image_sizes"] = forward_kwargs["image_sizes"]
+        if "token_type_ids" in forward_kwargs:
+            output["token_type_ids"] = forward_kwargs["token_type_ids"]
         if images is not None:
             output["images"] = images
         return output

trl/experimental/gspo_token/grpo_trainer.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ def _compute_loss(self, model, inputs):
             num_images=inputs.get("num_images"),
             pixel_attention_mask=inputs.get("pixel_attention_mask"),
             image_sizes=inputs.get("image_sizes"),
+            token_type_ids=inputs.get("token_type_ids"),
         )
 
         if self.top_entropy_quantile < 1.0:

trl/trainer/grpo_trainer.py

Lines changed: 12 additions & 0 deletions
@@ -800,6 +800,7 @@ def _get_per_token_logps_and_entropies(
         num_images=None,
         pixel_attention_mask=None,
         image_sizes=None,
+        token_type_ids=None,
     ) -> dict[str, Optional[torch.Tensor]]:
         """Compute log-probs and (optionally) entropies for each token."""
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak

@@ -827,6 +828,8 @@ def _get_per_token_logps_and_entropies(
                 model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
             if image_sizes is not None:
                 model_inputs["image_sizes"] = image_sizes[start : start + batch_size]
+            if token_type_ids is not None:
+                model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size]
 
             # Only add logits_to_keep if the model supports it
             if "logits_to_keep" in self.model_kwarg_keys:

@@ -1421,6 +1424,12 @@ def _generate_and_score_completions(
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
+        # If token_type_ids are used, extend them with zeros for the completion part
+        if "token_type_ids" in forward_kwargs:
+            token_type_ids = forward_kwargs["token_type_ids"]
+            forward_kwargs["token_type_ids"] = torch.cat(
+                [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
+            )
 
         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
         batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

@@ -1608,6 +1617,8 @@ def _generate_and_score_completions(
             output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
         if "image_sizes" in forward_kwargs:
             output["image_sizes"] = forward_kwargs["image_sizes"]
+        if "token_type_ids" in forward_kwargs:
+            output["token_type_ids"] = forward_kwargs["token_type_ids"]
         if images is not None:
             output["num_images"] = num_images
         return output

@@ -1685,6 +1696,7 @@ def _compute_loss(self, model, inputs):
             num_images=inputs.get("num_images"),
             pixel_attention_mask=inputs.get("pixel_attention_mask"),
             image_sizes=inputs.get("image_sizes"),
+            token_type_ids=inputs.get("token_type_ids"),
         )
 
         if self.top_entropy_quantile < 1.0:
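
In _get_per_token_logps_and_entropies above, token_type_ids is sliced per micro-batch exactly like input_ids, so each chunked forward pass receives aligned inputs. A rough standalone sketch of that pattern, with assumed shapes and a placeholder in place of the actual model call:

# Rough sketch of the chunked slicing pattern (assumed shapes, placeholder model call).
import torch

input_ids = torch.randint(0, 1000, (8, 12))   # (N, P+C) concatenated sequences
token_type_ids = torch.zeros_like(input_ids)  # already extended to (N, P+C)
batch_size = 4                                # micro-batch size for the forward pass

for start in range(0, input_ids.size(0), batch_size):
    model_inputs = {"input_ids": input_ids[start : start + batch_size]}
    if token_type_ids is not None:
        model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size]
    # logits = model(**model_inputs).logits  # the trainer's forward pass happens here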

trl/trainer/rloo_trainer.py

Lines changed: 12 additions & 0 deletions
@@ -790,6 +790,7 @@ def _get_per_token_logps_and_entropies(
         num_images=None,
         pixel_attention_mask=None,
         image_sizes=None,
+        token_type_ids=None,
     ) -> dict[str, Optional[torch.Tensor]]:
         """Compute log-probs and (optionally) entropies for each token."""
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak

@@ -818,6 +819,8 @@ def _get_per_token_logps_and_entropies(
                 model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
             if image_sizes is not None:
                 model_inputs["image_sizes"] = image_sizes[start : start + batch_size]
+            if token_type_ids is not None:
+                model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size]
 
             # Only add logits_to_keep if the model supports it
             if "logits_to_keep" in self.model_kwarg_keys:

@@ -1381,6 +1384,12 @@ def _generate_and_score_completions(
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
+        # If token_type_ids are used, extend them with zeros for the completion part
+        if "token_type_ids" in forward_kwargs:
+            token_type_ids = forward_kwargs["token_type_ids"]
+            forward_kwargs["token_type_ids"] = torch.cat(
+                [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
+            )
 
         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
         batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size

@@ -1521,6 +1530,8 @@ def _generate_and_score_completions(
             output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
         if "image_sizes" in forward_kwargs:
             output["image_sizes"] = forward_kwargs["image_sizes"]
+        if "token_type_ids" in forward_kwargs:
+            output["token_type_ids"] = forward_kwargs["token_type_ids"]
         if images is not None:
             output["num_images"] = num_images
         return output

@@ -1551,6 +1562,7 @@ def _compute_loss(self, model, inputs):
             num_images=inputs.get("num_images"),
             pixel_attention_mask=inputs.get("pixel_attention_mask"),
             image_sizes=inputs.get("image_sizes"),
+            token_type_ids=inputs.get("token_type_ids"),
         )
 
         logps = (per_token_logps * completion_mask).sum(1)  # mask out padding and tokens after EOS
