[Fix] bug_fix_for_empty_tensor_input (deepjavalibrary#928)
* bug_fix_empty_input
---------
Co-authored-by: KexinFeng <fenkexin@amazon.com>
KexinFeng authored Jul 11, 2023
1 parent 4c7be3e commit 666162a
Showing 2 changed files with 64 additions and 48 deletions.
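
Before the diff, a note on the failure mode: the scheduler keys its reusable kv-cache entries by the flattened prompt ids, and the old guard in add_request treated an empty key as "prompt not provided". A minimal standalone sketch of that behavior (not part of the commit; the empty tensor is built directly here rather than coming from a tokenizer):

    import torch

    # An empty request tokenizes to a (1, 0) tensor: a valid entry with zero tokens.
    prompt_ids_tensor = torch.empty(1, 0, dtype=torch.int64)

    key = tuple(prompt_ids_tensor.flatten().tolist())
    print(prompt_ids_tensor.numel())  # 0
    print(key)                        # () -- falsy

    # The old `if not key` guard therefore raised for an empty prompt that was in
    # fact provided, while a genuinely missing request uid hit a bare KeyError on
    # the dict lookup first. The new code checks dict membership before indexing.
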
85 changes: 47 additions & 38 deletions engines/python/setup/djl_python/scheduler/seq_batch_scheduler.py
@@ -66,45 +66,44 @@ def add_request(self,
         if search_configs:
             for idx, search_config in enumerate(search_configs):
                 if search_config.use_lru_kv_cache:
-                    prompt_ids_tensor = kv_cache_prompt_ids[
-                        request_uids[idx].item()]
-                    key = tuple(prompt_ids_tensor.flatten().tolist())
-                    if not key:
+                    request_uid = request_uids[idx].item()
+                    if request_uid not in kv_cache_prompt_ids:
                         raise Exception(
-                            f"request_uids = {request_uids[idx]}: search_config says use_kv_cache_prompt, "
+                            f"request_uids = {request_uid}: search_config says use_kv_cache_prompt, "
                             f"but the prompt_ids is not provided.")
+                    prompt_ids_tensor = kv_cache_prompt_ids[request_uid]
+                    key = tuple(prompt_ids_tensor.flatten().tolist())
+                    # lru operations
+                    if key not in self.lru_kv_cache:
+                        if len(self.lru_kv_cache) + 1 > self.lru_max_size:
+                            # If cache size exceeds the maximum, remove by FIFO order
+                            self.lru_kv_cache.popitem(last=False)
+                        kv_cache_tuple = compute_kv_cache(
+                            input_ids=prompt_ids_tensor,
+                            lm_block=self.lm_block,
+                            search_configs=[search_config])
+                        kv_cache_new = []
+                        for k, v in kv_cache_tuple:
+                            k_new = k.cpu()
+                            v_new = v.cpu()
+                            kv_cache_new.append((k_new, v_new))
+                        self.lru_kv_cache[key] = tuple(kv_cache_new)
+                        self.lru_kv_cache.move_to_end(key)
+
+                        # _add_request
+                        self._add_request(input_ids[idx].view(1, -1),
+                                          request_uids[idx].view(1, -1),
+                                          search_algorithm,
+                                          [search_config],
+                                          kv_cache=kv_cache_tuple)
                     else:
-                        # lru operations
-                        if key not in self.lru_kv_cache:
-                            if len(self.lru_kv_cache) + 1 > self.lru_max_size:
-                                # If cache size exceeds the maximum, remove by FIFO order
-                                self.lru_kv_cache.popitem(last=False)
-                            kv_cache_tuple = compute_kv_cache(
-                                input_ids=prompt_ids_tensor,
-                                lm_block=self.lm_block,
-                                search_configs=[search_config])
-                            kv_cache_new = []
-                            for k, v in kv_cache_tuple:
-                                k_new = k.cpu()
-                                v_new = v.cpu()
-                                kv_cache_new.append((k_new, v_new))
-                            self.lru_kv_cache[key] = tuple(kv_cache_new)
-                            self.lru_kv_cache.move_to_end(key)
-
-                            # _add_request
-                            self._add_request(input_ids[idx].view(1, -1),
-                                              request_uids[idx].view(1, -1),
-                                              search_algorithm,
-                                              [search_config],
-                                              kv_cache=kv_cache_tuple)
-                        else:
-                            # _add_request
-                            self._add_request(input_ids[idx].view(1, -1),
-                                              request_uids[idx].view(1, -1),
-                                              search_algorithm,
-                                              [search_config],
-                                              kv_cache=self.lru_kv_cache[key])
-                            self.lru_kv_cache.move_to_end(key)
+                        # _add_request
+                        self._add_request(input_ids[idx].view(1, -1),
+                                          request_uids[idx].view(1, -1),
+                                          search_algorithm,
+                                          [search_config],
+                                          kv_cache=self.lru_kv_cache[key])
+                        self.lru_kv_cache.move_to_end(key)
                 else:
                     index_not_use_prompt.append(idx)
                     search_configs_not_use_prompt.append(search_config)
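
The eviction calls in the hunk above (popitem(last=False) and move_to_end) follow the usual collections.OrderedDict LRU pattern. The sketch below shows just that pattern in isolation, assuming self.lru_kv_cache is an OrderedDict capped at self.lru_max_size entries; the keys and values are toy stand-ins.

    from collections import OrderedDict

    lru_kv_cache = OrderedDict()
    lru_max_size = 2

    def put(key, value):
        if key not in lru_kv_cache:
            if len(lru_kv_cache) + 1 > lru_max_size:
                # Evict the entry at the front of the ordering, i.e. the least
                # recently used key.
                lru_kv_cache.popitem(last=False)
            lru_kv_cache[key] = value
        # Mark the key as most recently used by moving it to the back.
        lru_kv_cache.move_to_end(key)

    put((1, 2, 3), "kv_a")
    put((4, 5), "kv_b")
    put((1, 2, 3), "kv_a")      # cache hit: only refreshes recency
    put((6,), "kv_c")           # evicts (4, 5), the least recently used key
    print(list(lru_kv_cache))   # [(1, 2, 3), (6,)]
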
@@ -116,8 +115,7 @@ def add_request(self,
             index_not_use_prompt = torch.tensor(index_not_use_prompt)
             self._add_request(input_ids[index_not_use_prompt],
                               request_uids[index_not_use_prompt],
-                              search_algorithm,
-                              search_configs_not_use_prompt,
+                              search_algorithm, search_configs_not_use_prompt,
                               kv_cache)
 
     def _add_request(self,
@@ -146,6 +144,17 @@ def _add_request(self,
 
         seq_batcher_cls = self.default_seq_batcher_cls if seq_batcher_cls is None else seq_batcher_cls
 
+        # Corner case: input_ids are empty. Pad them.
+        if input_ids.numel() == 0:
+            batch_size = input_ids.shape[0]
+            input_ids = torch.zeros(batch_size,
+                                    1,
+                                    dtype=torch.int64,
+                                    device=input_ids.device)
+            for i in range(batch_size):
+                input_ids[i, 0] = self.default_search_configs[
+                    request_uids[i].item()].pad_token_id
+
         # Prefill
         new_seq_batcher, output_ids = seq_batcher_cls.init_forward(
             input_ids=input_ids,
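
The added corner case is easy to reproduce outside the scheduler. A standalone sketch of the same padding step, with a hypothetical pad_token_id of 50256 standing in for the value read from the request's SearchConfig:

    import torch

    pad_token_id = 50256                              # hypothetical; the patch reads it from SearchConfig
    input_ids = torch.empty(1, 0, dtype=torch.int64)  # empty input: shape (1, 0), numel() == 0

    if input_ids.numel() == 0:
        batch_size = input_ids.shape[0]
        padded = torch.zeros(batch_size, 1, dtype=torch.int64,
                             device=input_ids.device)
        padded[:, 0] = pad_token_id
        input_ids = padded

    print(input_ids.shape)   # torch.Size([1, 1])
    print(input_ids)         # tensor([[50256]])
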
27 changes: 17 additions & 10 deletions engines/python/setup/djl_python/tests/test_scheduler.py
@@ -455,16 +455,23 @@ def test_lru_kv_cache(self):
         prompt_ids_dict = {1: prompt_ids, 2: prompt_ids}
 
         # Load a kv_cache from file and test merging a shorter sequence
-        input_ids = tokenizer(
-            [r"When your legs don't work", r"'t remember", r""],
-            return_tensors='pt',
-            padding=True).input_ids
-        request_ids = torch.tensor([[0], [1], [2]])
-        search_configs = [
-            SearchConfig(),
-            SearchConfig(use_lru_kv_cache=True),
-            SearchConfig(use_lru_kv_cache=True)
-        ]
+        input_ids = tokenizer([r"When your legs don't work", r"'t remember"],
+                              return_tensors='pt',
+                              padding=True).input_ids
+        request_ids = torch.tensor([[0], [1]])
+        search_configs = [SearchConfig(), SearchConfig(use_lru_kv_cache=True)]
+
+        # Load a kv_cache file to simulate a fixed reusable prefix which is pre-calculated
+        scheduler.add_request(input_ids,
+                              request_ids,
+                              search_configs=search_configs,
+                              kv_cache_prompt_ids=prompt_ids_dict)
+
+        # Test empty input_ids
+        input_ids = tokenizer([r""], return_tensors='pt',
+                              padding=True).input_ids.view(1, -1)
+        request_ids = torch.tensor([[2]])
+        search_configs = [SearchConfig(use_lru_kv_cache=True)]
 
         # Load a kv_cache file to simulate a fixed reusable prefix which is pre-calculated
         scheduler.add_request(input_ids,
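
The shape that the new test case feeds into the scheduler can be checked directly. A hedged sketch, assuming the GPT-2 tokenizer used elsewhere in this test file, which emits no tokens for an empty string:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 defines no pad token by default

    input_ids = tokenizer([r""], return_tensors='pt',
                          padding=True).input_ids.view(1, -1)
    print(input_ids.shape)   # torch.Size([1, 0]), so numel() == 0

    # add_request forwards this tensor to _add_request, which now pads it to a
    # single pad_token_id column instead of failing on the empty input.
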
