Skip to content

Commit

Permalink
fix: improve find_segments via numpy diff (#2686)
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh authored Nov 18, 2024
1 parent 52e4873 commit fea62e9
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions server/text_generation_server/utils/segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,25 @@
from typing import List, Tuple, Union

import torch
import numpy as np


# FIXME: this should be optimized
def find_segments(
    adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
    """Split a sequence of adapter indices into runs of equal consecutive values.

    Args:
        adapter_indices: One-dimensional sequence of adapter ids, given either
            as a ``torch.Tensor`` or as a Python sequence of ints.

    Returns:
        A ``(segments, segment_indices)`` tuple where ``segments`` holds the
        run boundaries (start offset of every run followed by the total
        length, so it has one more entry than there are runs) and
        ``segment_indices`` holds the adapter id of each run.

        For empty input the result is ``([0], [])``.
    """
    if isinstance(adapter_indices, torch.Tensor):
        # Move to host memory once instead of reading elements one at a time.
        adapter_indices = adapter_indices.cpu().numpy()
    else:
        adapter_indices = np.asarray(adapter_indices)

    num_indices = len(adapter_indices)
    if num_indices == 0:
        # Nothing to segment; indexing element 0 below would raise otherwise.
        return [0], []

    # Position i (i >= 1) starts a new run when it differs from position i - 1.
    # Comparing neighbours directly avoids arithmetic on the values themselves.
    boundaries = np.flatnonzero(adapter_indices[1:] != adapter_indices[:-1]) + 1

    run_starts = [0] + boundaries.tolist()
    segments = run_starts + [num_indices]
    segment_indices = adapter_indices[run_starts].tolist()

    return segments, segment_indices

Expand Down

0 comments on commit fea62e9

Please sign in to comment.