Skip to content

Commit

Permalink
fix index error when computing ppl on long-text prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
lvhan028 committed Nov 1, 2024
1 parent e034610 commit 103e9ad
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions lmdeploy/serve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,16 @@ def get_ppl(self, input_ids: Union[List[int],
logger.info(f'sorted indices: {indices}')
for (start, end) in self._batch_iterator(sizes, max_input_len):
logger.info(f'start: {start}, end: {end}')
_input_ids = [input_ids[indices[i]] for i in range(start, end)]
if start == end:
_input_ids = input_ids[indices[start]]
loss, target_count = self._get_long_text_ppl(
generator=generator,
input_ids=_input_ids,
max_input_len=max_input_len)
losses.append(loss)
target_counts.append(target_count)
else:
_input_ids = [input_ids[indices[i]] for i in range(start, end)]
loss, target_count = self._get_ppl(
generator=generator,
input_ids=_input_ids,
Expand Down Expand Up @@ -261,24 +262,24 @@ def _batch_iterator(self, sizes, max_value):
i += 1

def _get_long_text_ppl(self, generator, input_ids, max_input_len):
assert isinstance(input_ids, List) and len(input_ids) == 1
seq_len = len(input_ids[0])
assert all(isinstance(_, int) for _ in input_ids)
seq_len = len(input_ids)
assert seq_len > max_input_len
logger.info(f'get long text ppl: seq_len {seq_len}')

losses = []
target_counts = []
for i in range(0, seq_len, max_input_len):
token_ids = input_ids[:, i:i + max_input_len]
token_ids = input_ids[i:i + max_input_len]
step = [i]
# shift token_ids by 1 to the left
target_ids = input_ids[:, i + 1:i + 1 + max_input_len]
target_ids = input_ids[i + 1:i + 1 + max_input_len]

loss, target_count = self._get_ppl(
generator=generator,
input_ids=token_ids,
input_ids=[token_ids],
max_input_len=max_input_len,
target_ids=target_ids,
target_ids=[target_ids],
steps=step,
sequence_start=(i == 0),
sequence_end=(i + max_input_len >= seq_len))
Expand Down

0 comments on commit 103e9ad

Please sign in to comment.