From 103e9ad8342b69ac2d7f0a4c3b8c6e99b6bd97f7 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Fri, 1 Nov 2024 15:54:54 +0800 Subject: [PATCH] fix index error when computing ppl on long-text prompt --- lmdeploy/serve/utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/lmdeploy/serve/utils.py b/lmdeploy/serve/utils.py index 4791d3c72..3a16f0a65 100644 --- a/lmdeploy/serve/utils.py +++ b/lmdeploy/serve/utils.py @@ -212,8 +212,8 @@ def get_ppl(self, input_ids: Union[List[int], logger.info(f'sorted indices: {indices}') for (start, end) in self._batch_iterator(sizes, max_input_len): logger.info(f'start: {start}, end: {end}') - _input_ids = [input_ids[indices[i]] for i in range(start, end)] if start == end: + _input_ids = input_ids[indices[start]] loss, target_count = self._get_long_text_ppl( generator=generator, input_ids=_input_ids, @@ -221,6 +221,7 @@ def get_ppl(self, input_ids: Union[List[int], losses.append(loss) target_counts.append(target_count) else: + _input_ids = [input_ids[indices[i]] for i in range(start, end)] loss, target_count = self._get_ppl( generator=generator, input_ids=_input_ids, @@ -261,24 +262,24 @@ def _batch_iterator(self, sizes, max_value): i += 1 def _get_long_text_ppl(self, generator, input_ids, max_input_len): - assert isinstance(input_ids, List) and len(input_ids) == 1 - seq_len = len(input_ids[0]) + assert all(isinstance(_, int) for _ in input_ids) + seq_len = len(input_ids) assert seq_len > max_input_len logger.info(f'get long text ppl: seq_len {seq_len}') losses = [] target_counts = [] for i in range(0, seq_len, max_input_len): - token_ids = input_ids[:, i:i + max_input_len] + token_ids = input_ids[i:i + max_input_len] step = [i] # shift token_ids by 1 to the left - target_ids = input_ids[:, i + 1:i + 1 + max_input_len] + target_ids = input_ids[i + 1:i + 1 + max_input_len] loss, target_count = self._get_ppl( generator=generator, - input_ids=token_ids, + input_ids=[token_ids], max_input_len=max_input_len, - target_ids=target_ids, + target_ids=[target_ids], steps=step, sequence_start=(i == 0), sequence_end=(i + max_input_len >= seq_len))