From 3f59821dbf04fbcb3538afcfd311e14368c11f22 Mon Sep 17 00:00:00 2001 From: Xianyu2 Date: Mon, 8 Jul 2024 02:35:34 +0000 Subject: [PATCH] support vllm 0.5.1 --- openrlhf/trainer/ppo_utils/experience_maker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openrlhf/trainer/ppo_utils/experience_maker.py b/openrlhf/trainer/ppo_utils/experience_maker.py index 347176707..798eabc52 100644 --- a/openrlhf/trainer/ppo_utils/experience_maker.py +++ b/openrlhf/trainer/ppo_utils/experience_maker.py @@ -369,11 +369,11 @@ def _generate_vllm(self, prompts: List[str], **kwargs) -> Tuple[torch.Tensor, to for output in outputs: # left padding input input_len = len(output.prompt_token_ids) - input_ids = [pad_token_id] * (max_input_len - input_len) + output.prompt_token_ids + input_ids = [pad_token_id] * (max_input_len - input_len) + list(output.prompt_token_ids) # right padding output output_len = len(output.outputs[0].token_ids) - output_ids = output.outputs[0].token_ids + [pad_token_id] * (max_output_len - output_len) + output_ids = list(output.outputs[0].token_ids) + [pad_token_id] * (max_output_len - output_len) if output_ids[output_len - 1] != eos_token_id: output_ids[min(output_len, len(output_ids) - 1)] = eos_token_id