From af23e2db5c8bf6d89f4c4ce1321f3aa6ef9e6905 Mon Sep 17 00:00:00 2001 From: Hongwen Xin <55238914+penPenf28@users.noreply.github.com> Date: Tue, 10 Sep 2024 14:42:43 +0800 Subject: [PATCH] [Bugfix] fix multi-gpu infer (#9107) --- llm/predict/predictor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 13eb89d906b6..d6dbbcc2c72c 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -993,7 +993,8 @@ def predict(self, input_texts: list[str], return_tokens=False): output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu() tensor_queue.put(output_tensor) - done_event.wait() + if self.tensor_parallel_rank == 0: + done_event.wait() s_time = time.time() while self.model_inputs["not_need_stop"]: self._infer(self.model_inputs) @@ -1119,7 +1120,8 @@ def predict(self, input_texts: list[str], return_tokens=False): read_res_process.start() output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu() tensor_queue.put(output_tensor) - done_event.wait() + if self.tensor_parallel_rank == 0: + done_event.wait() s_time = time.time() while self.model_inputs["not_need_stop"]: self.predictor.run(list(self.model_inputs.values()))