diff --git a/openrlhf/trainer/ppo_utils/experience_maker.py b/openrlhf/trainer/ppo_utils/experience_maker.py
index 798eabc52..347176707 100644
--- a/openrlhf/trainer/ppo_utils/experience_maker.py
+++ b/openrlhf/trainer/ppo_utils/experience_maker.py
@@ -369,11 +369,11 @@ def _generate_vllm(self, prompts: List[str], **kwargs) -> Tuple[torch.Tensor, to
         for output in outputs:
             # left padding input
             input_len = len(output.prompt_token_ids)
-            input_ids = [pad_token_id] * (max_input_len - input_len) + list(output.prompt_token_ids)
+            input_ids = [pad_token_id] * (max_input_len - input_len) + output.prompt_token_ids
 
             # right padding output
             output_len = len(output.outputs[0].token_ids)
-            output_ids = list(output.outputs[0].token_ids) + [pad_token_id] * (max_output_len - output_len)
+            output_ids = output.outputs[0].token_ids + [pad_token_id] * (max_output_len - output_len)
             if output_ids[output_len - 1] != eos_token_id:
                 output_ids[min(output_len, len(output_ids) - 1)] = eos_token_id
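
For context, below is a minimal standalone sketch of the padding scheme the loop above implements: prompts are left-padded to the longest prompt in the batch, generations are right-padded to the longest generation, and an EOS token is forced onto sequences that stopped without emitting one. The helper name pad_and_mark_eos and the toy token ids are made up for illustration and are not part of OpenRLHF; the list(...) conversions are a defensive assumption, since depending on the vLLM version the token id sequences may arrive as tuples rather than lists.

from typing import List, Sequence, Tuple


def pad_and_mark_eos(
    prompt_token_ids: Sequence[int],
    output_token_ids: Sequence[int],
    max_input_len: int,
    max_output_len: int,
    pad_token_id: int,
    eos_token_id: int,
) -> Tuple[List[int], List[int]]:
    """Left-pad a prompt and right-pad a generation, forcing a trailing EOS."""
    # Left padding: pad tokens go in front, so every prompt ends at the same
    # column and the prompt/response boundary is aligned across the batch.
    # list(...) is defensive: if the ids are a tuple, `list + tuple` raises TypeError.
    input_ids = [pad_token_id] * (max_input_len - len(prompt_token_ids)) + list(prompt_token_ids)

    # Right padding: pad tokens go after the generated tokens.
    output_len = len(output_token_ids)
    output_ids = list(output_token_ids) + [pad_token_id] * (max_output_len - output_len)

    # If generation stopped without emitting EOS (e.g. it hit the length limit),
    # write EOS into the slot right after the last generated token, or into the
    # final slot when the sequence already fills max_output_len.
    if output_ids[output_len - 1] != eos_token_id:
        output_ids[min(output_len, len(output_ids) - 1)] = eos_token_id
    return input_ids, output_ids


if __name__ == "__main__":
    ids_in, ids_out = pad_and_mark_eos(
        prompt_token_ids=(5, 6, 7),   # tuple on purpose, mimicking a vLLM output
        output_token_ids=(8, 9),
        max_input_len=5,
        max_output_len=4,
        pad_token_id=0,
        eos_token_id=2,
    )
    print(ids_in)   # [0, 0, 5, 6, 7]  -> prompt left-padded
    print(ids_out)  # [8, 9, 2, 0]     -> EOS forced after the generated tokens

Keeping prompts left-padded and responses right-padded fixes the prompt/response boundary at a single offset for the whole batch, which makes it straightforward to build attention and action masks downstream.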