2 changes: 2 additions & 0 deletions trl/__init__.py
@@ -69,6 +69,7 @@
"GKDTrainer",
"GRPOConfig",
"GRPOTrainer",
"Qwen3VLGRPOTrainer",
"HfPairwiseJudge",
"KTOConfig",
"KTOTrainer",
@@ -150,6 +151,7 @@
         GKDTrainer,
         GRPOConfig,
         GRPOTrainer,
+        Qwen3VLGRPOTrainer,
         HfPairwiseJudge,
         KTOConfig,
         KTOTrainer,
4 changes: 2 additions & 2 deletions trl/trainer/__init__.py
@@ -36,7 +36,7 @@
"gkd_config": ["GKDConfig"],
"gkd_trainer": ["GKDTrainer"],
"grpo_config": ["GRPOConfig"],
"grpo_trainer": ["GRPOTrainer"],
"grpo_trainer": ["GRPOTrainer", "Qwen3VLGRPOTrainer"],
"judges": [
"AllTrueJudge",
"BaseBinaryJudge",
@@ -96,7 +96,7 @@
 from .gkd_config import GKDConfig
 from .gkd_trainer import GKDTrainer
 from .grpo_config import GRPOConfig
-from .grpo_trainer import GRPOTrainer
+from .grpo_trainer import GRPOTrainer, Qwen3VLGRPOTrainer
 from .judges import (
     AllTrueJudge,
     BaseBinaryJudge,
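With both registries updated, `Qwen3VLGRPOTrainer` is importable from the package root as well as from the trainer subpackage. A quick import check, assuming the diff above is applied to a TRL checkout:

from trl import Qwen3VLGRPOTrainer
from trl.trainer import Qwen3VLGRPOTrainer as _SameClass

# The lazy `_import_structure` entry and the eager import resolve to the
# same class defined in trl/trainer/grpo_trainer.py.
assert Qwen3VLGRPOTrainer is _SameClass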
146 changes: 146 additions & 0 deletions trl/trainer/grpo_trainer.py
@@ -2026,3 +2026,149 @@ def _save_checkpoint(self, model, trial):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)


class Qwen3VLGRPOTrainer(GRPOTrainer):
r"""
A GRPOTrainer specialization for Qwen3-VL (multi-modal, video-based).

Motivation
----------
- Qwen3-VL uses a Processor that expects full multi-modal conversations
(with `"type": "video"`, `"fps"`, etc.) in `apply_chat_template`.
- The generic GRPOTrainer `_generate_single_turn` assumes text-only or
image-style inputs and its batching/vLLM logic currently does not
work with Qwen3-VL video inputs.
- This subclass keeps the GRPO logic intact, but overrides
`_generate_single_turn` with a simple, per-sample generation loop
that is compatible with Qwen3-VL's Processor.

Expected `prompts` format
-------------------------
`prompts` is a list of conversations, each being a list of messages:

[
[
{
"role": "system",
"content": [
{"type": "text", "text": "..."},
...
],
},
{
"role": "user",
"content": [
{"type": "video", "video": "<path-or-bytes>", "fps": 4},
{"type": "text", "text": "Question ..."},
],
},
],
...
]

    `_generate_single_turn` returns:
- prompt_ids: List[List[int]]
- completion_ids: List[List[int]]
- logprobs: None (not used in this path)
- extra_fields: dict (empty for now)
"""

def _generate_single_turn(self, prompts: list):
if self.use_vllm or self.use_transformers_paged:
raise ValueError(
"Qwen3VLVideoGRPOTrainer currently supports only the standard "
"transformers.generate path. Please set `use_vllm=False` and "
"`use_transformers_paged=False` in GRPOConfig."
)

device = self.accelerator.device

cleaned_prompts: list[list[dict]] = []
for conv in prompts:
if not isinstance(conv, list):
cleaned_prompts.append(conv)
continue

new_conv = []
            for msg in conv:
                role = msg.get("role", "user")
                content = msg.get("content", [])
                # Text-only messages may carry plain-string content; normalize it
                # to the chunked format the Qwen3-VL processor expects.
                if isinstance(content, str):
                    content = [{"type": "text", "text": content}]
                new_content = []

for chunk in content:
ctype = chunk.get("type")

if ctype == "text":
new_content.append(
{
"type": "text",
"text": chunk.get("text", ""),
}
)
elif ctype == "image":
new_content.append(
{
"type": "image",
"image": chunk.get("image"),
}
)
elif ctype == "video":
new_content.append(
{
"type": "video",
"video": chunk.get("video"),
"fps": chunk.get("fps", None),
}
)
else:
new_content.append(chunk)

new_conv.append({"role": role, "content": new_content})
cleaned_prompts.append(new_conv)

prompt_ids_list: list[list[int]] = []
completion_ids_list: list[list[int]] = []
logprobs = None
extra_fields: dict[str, list] = {}

gen_config = self.generation_config

model = self.accelerator.unwrap_model(self.model)
was_training = model.training
model.eval()

with torch.no_grad():
for conv in cleaned_prompts:
processor_inputs = self.processing_class.apply_chat_template(
conv,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
**self.chat_template_kwargs,
)

                # `apply_chat_template(return_dict=True)` may include non-tensor
                # entries; move only objects that implement `.to`.
                processor_inputs = {
                    k: v.to(device) if hasattr(v, "to") else v
                    for k, v in processor_inputs.items()
                }

output_ids = model.generate(
**processor_inputs,
generation_config=gen_config,
                    # disable_compile=True  # only needed if the user compiles the model; leave that to outside configuration
)

input_ids = processor_inputs["input_ids"] # [1, L_prompt]
                assert (
                    output_ids.shape[0] == 1
                ), "Qwen3VLGRPOTrainer expects per-sample generation with batch size 1."

full_ids = output_ids[0]
prompt_len = input_ids.shape[1]

prompt_ids_list.append(input_ids[0].tolist())
completion_ids_list.append(full_ids[prompt_len:].tolist())

if was_training:
model.train()

return prompt_ids_list, completion_ids_list, logprobs, extra_fields
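
Taken together, the subclass trades GRPO's batched generation for a per-sample `generate` loop that the Qwen3-VL processor can handle. A minimal usage sketch (the checkpoint id, reward function, and dataset below are illustrative, not part of this diff; the one hard requirement is leaving `use_vllm` and `use_transformers_paged` disabled):

from datasets import Dataset
from trl import GRPOConfig, Qwen3VLGRPOTrainer

def brevity_reward(completions, **kwargs):
    # Toy reward for illustration only: prefer shorter completions.
    return [-float(len(str(c))) for c in completions]

train_dataset = Dataset.from_list(
    [
        {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "video": "clips/example.mp4", "fps": 4},
                        {"type": "text", "text": "Describe what happens in the video."},
                    ],
                }
            ]
        }
    ]
)

training_args = GRPOConfig(output_dir="qwen3vl-grpo", use_vllm=False)
trainer = Qwen3VLGRPOTrainer(
    model="Qwen/Qwen3-VL-8B-Instruct",  # illustrative model id
    reward_funcs=brevity_reward,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()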