Max_seqbatcher_number_threshold_api (deepjavalibrary#843)
KexinFeng authored Jun 21, 2023
1 parent 8e0c377 commit 45ed4c5
Showing 4 changed files with 57 additions and 11 deletions.
33 changes: 33 additions & 0 deletions engines/python/setup/djl_python/scheduler/seq_batch_scheduler.py
@@ -45,6 +45,7 @@ def add_request(self,
search_configs: List[SearchConfig] = None,
kv_cache: Union[Tuple, None] = None,
save_kv_cache_path: str = None):
# TODO: next, this will take an argument of `action`, computed by self.optimal_action.
device = input_ids.device
request_uids = request_uids.to(device)
if kv_cache:
@@ -72,6 +73,7 @@ def add_request(self,
save_kv_cache_path=save_kv_cache_path)

# merge
# TODO: next, an optimal action needs to be first computed, according to which the merge is done.
if not self.seq_batchers[seq_batcher_cls]:
self.seq_batchers[seq_batcher_cls].append(new_seq_batcher)
else:
@@ -101,6 +103,29 @@ def total_batch_size(self) -> Dict[Type[SeqBatcher], int]:
for seq_batcher in seq_batcher_list)
return batch_size
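The commit title refers to a max_seq_batcher-number threshold; below is a minimal, hypothetical sketch of how a consumer could combine the seq_batcher count and `total_batch_size` in such a check. The threshold names and the `scheduler` instance are assumptions, not part of this diff.

# Hypothetical consumer-side check; threshold names are assumptions used only for illustration.
MAX_SEQ_BATCHER_NUMBER = 4
MAX_TOTAL_BATCH_SIZE = 32

def can_accept_request(scheduler, seq_batcher_cls) -> bool:
    # Number of seq_batchers currently held for this SeqBatcher type.
    num_seq_batchers = len(scheduler.seq_batchers[seq_batcher_cls])
    # total_batch_size() returns {SeqBatcher subclass: summed batch size}.
    total = scheduler.total_batch_size().get(seq_batcher_cls, 0)
    return (num_seq_batchers < MAX_SEQ_BATCHER_NUMBER
            and total < MAX_TOTAL_BATCH_SIZE)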

def optimal_action(self,
input_ids: torch.Tensor,
request_uids: torch.Tensor,
seq_batcher_cls: Type[SeqBatcher] = None,
search_configs: List[SearchConfig] = None,
kv_cache: Union[Tuple, None] = None,
save_kv_cache_path: str = None):
"""
Get the optimal merging action computed according to the added request and the current scheduler status.
Args:
The request information.
Return:
Optimal merging action: `Action`:
1. choose a seq_batcher to merge in
2. split a seq_batcher
3. rearrange the whole seq_batcher list
"""

# This is provided to the consumers to be used as part of the max_seq_batcher thresholding mechanism.
pass
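`Action` is referenced in the docstring but not defined in this commit; the sketch below is a hypothetical illustration of the three outcomes it would need to express. All names here are assumptions, not part of this diff.

# Hypothetical sketch of an `Action` result type; not part of this commit.
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class ActionKind(Enum):
    MERGE = "merge"          # merge the new request into an existing seq_batcher
    SPLIT = "split"          # split a seq_batcher before merging
    REARRANGE = "rearrange"  # rearrange the whole seq_batcher list


@dataclass
class Action:
    kind: ActionKind
    seq_batcher_idx: Optional[int] = None         # target seq_batcher, when applicable
    partitions: Optional[List[List[int]]] = None  # split layout, when kind is SPLIT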

def inference_call(self) -> Tuple[List[List[int]], List[int], List[int]]:
"""
A sweep of inference calls on all seq_batchers in the scheduler
@@ -154,6 +179,14 @@ def collect_results(self):

def seq_batcher_split(self, seq_batcher_cls: Type[SeqBatcher],
seq_batcher_idx: int, partitions: List[List[int]]):
"""
Split a seq_batcher in the seq_batcher_list located at seq_batcher_idx, into parts according to `partition`.
Args:
seq_batcher_cls: SeqBatcher type
seq_batcher_idx: idx in the seq_batcher_list
partitions: contains the seq_batcher_idx partitioned into lists.
"""

seq_batcher = self.seq_batchers[seq_batcher_cls].pop(seq_batcher_idx)
self.seq_batchers[seq_batcher_cls].extend(
seq_batcher.split(partitions))
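A minimal usage sketch of `seq_batcher_split`, assuming a scheduler that currently holds one seq_batcher with four requests at batch positions 0-3; the `scheduler` instance and its contents are assumptions, not part of this diff.

# Illustrative only; assumes one GreedySeqBatcher with four requests.
scheduler.seq_batcher_split(
    seq_batcher_cls=GreedySeqBatcher,
    seq_batcher_idx=0,               # which seq_batcher in the list to split
    partitions=[[0, 1], [2, 3]])     # batch positions grouped into the new parts
# The original seq_batcher is popped and replaced by two smaller seq_batchers.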
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/scheduler/seq_batcher.py
@@ -51,7 +51,8 @@ def __init__(self, batch: Batch, request_uids: torch.Tensor,
# This is cached output of sampler_bucket_sort result used through inferences.
self.sampler_bucket_sort_cache: Union[Tuple[Dict[str, torch.tensor],
List[SearchConfig],
List[SearchConfig]], None] = None
List[SearchConfig]],
None] = None

@classmethod
@abstractmethod
17 changes: 12 additions & 5 deletions engines/python/setup/djl_python/scheduler/seq_batcher_impl.py
@@ -81,7 +81,8 @@ def init_forward(
search_config_list = [
search_configs[r] for r in request_uids.view(-1).tolist()
]
next_input_ids = sampling_step_generate(last_logits, search_configs=search_config_list)
next_input_ids = sampling_step_generate(
last_logits, search_configs=search_config_list)
batch = Batch(next_input_ids=next_input_ids,
past_key_values=past_key_values)
if kv_cache is not None:
@@ -129,11 +130,17 @@ def forward(self) -> List[List[int]]:
# Create SeqBatcher
last_logits = logits[:, -1, :] # logits: [batch, sequence, vocab_dim]
if not self.search_config_list_cache:
self.search_config_list_cache = [self.search_configs[r] for r in self.request_uids.view(-1).tolist()]
self.search_config_list_cache = [
self.search_configs[r]
for r in self.request_uids.view(-1).tolist()
]
if not self.sampler_bucket_sort_cache:
self.sampler_bucket_sort_cache = sampler_bucket_sort(self.search_config_list_cache)
next_input_ids = sampling_step_generate(last_logits, search_configs=self.search_config_list_cache,
sampler_bucket_sort_cache=self.sampler_bucket_sort_cache)
self.sampler_bucket_sort_cache = sampler_bucket_sort(
self.search_config_list_cache)
next_input_ids = sampling_step_generate(
last_logits,
search_configs=self.search_config_list_cache,
sampler_bucket_sort_cache=self.sampler_bucket_sort_cache)
self.batch = self._get_batch_cls()(past_key_values=past_key_values,
next_input_ids=next_input_ids)
self.seq_len += 1
15 changes: 10 additions & 5 deletions engines/python/setup/djl_python/scheduler/step_generation.py
@@ -60,7 +60,9 @@ def contrastive_step_generate(top_k_ids: torch.Tensor,
return output_ids, select


def sampling_step_generate(logits: torch.tensor, search_configs: List[SearchConfig], sampler_bucket_sort_cache=None):
def sampling_step_generate(logits: torch.tensor,
search_configs: List[SearchConfig],
sampler_bucket_sort_cache=None):
"""
Greedy, topK, topP
@@ -74,14 +76,17 @@ def sampling_step_generate(logits: torch.tensor, search_configs: List[SearchConfig],
token_id: [batch, 1]
"""
collector, k_config_list, tmprtr_list_for_k, p_config_list, tmprtr_list_for_p = sampler_bucket_sort(
search_configs) if not sampler_bucket_sort_cache else sampler_bucket_sort_cache
search_configs
) if not sampler_bucket_sort_cache else sampler_bucket_sort_cache

output_ids_greedy = greedy_step_generate(logits[collector['greedy'], :])
output_ids_topk = topk_step_generate(logits[collector['topk'], :],
k_config_list, tmprtr_list_for_k)
k_config_list, tmprtr_list_for_k)
output_ids_topp = topp_step_generate(logits[collector['topp'], :],
p_config_list, tmprtr_list_for_p)
output_ids = torch.empty(len(search_configs), dtype=torch.int64, device=logits.device)
p_config_list, tmprtr_list_for_p)
output_ids = torch.empty(len(search_configs),
dtype=torch.int64,
device=logits.device)
output_ids[collector['greedy']] = output_ids_greedy.view(-1)
output_ids[collector['topk']] = output_ids_topk.view(-1)
output_ids[collector['topp']] = output_ids_topp.view(-1)
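`sampling_step_generate` unpacks five values from `sampler_bucket_sort`: a collector mapping strategy names to batch positions, plus the top-k/top-p configs and temperatures for those positions. The function below is a rough reconstruction of that contract, inferred from the call sites above; the `SearchConfig` attribute names used here (`sampling`, `topk`, `topp`, `temperature`) are assumptions and may differ from the real class.

def sampler_bucket_sort_sketch(search_configs):
    # Rough sketch of the expected return shape; not the actual implementation.
    collector = {'greedy': [], 'topk': [], 'topp': []}  # batch positions per strategy
    k_config_list, tmprtr_list_for_k = [], []           # k values and temperatures
    p_config_list, tmprtr_list_for_p = [], []           # p values and temperatures
    for idx, cfg in enumerate(search_configs):
        if getattr(cfg, 'sampling', False) and getattr(cfg, 'topk', 0) > 0:
            collector['topk'].append(idx)
            k_config_list.append(cfg.topk)
            tmprtr_list_for_k.append(cfg.temperature)
        elif getattr(cfg, 'sampling', False) and getattr(cfg, 'topp', 0) > 0:
            collector['topp'].append(idx)
            p_config_list.append(cfg.topp)
            tmprtr_list_for_p.append(cfg.temperature)
        else:
            collector['greedy'].append(idx)
    return (collector, k_config_list, tmprtr_list_for_k,
            p_config_list, tmprtr_list_for_p)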
