
Commit 70d6167

Support only one GPU store partially with bug
1 parent 589d81a commit 70d6167

File tree

5 files changed: +110 -68 lines changed

lightllm/server/router/dynamic_prompt/hiradix_cache.py

Lines changed: 91 additions & 57 deletions
@@ -17,6 +17,9 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_
         logger.info("Initializing HiRadixCache")
         self.rank_in_node = rank_in_node
         try:
+            # TODO: determine by model type && dp, tp
+            store_once = True  # Deepseek -> True, Llama -> False
+            self.do_store = store_once and self.rank_in_node == 0
             self.is_hi_radix_cache = True
             all_buffers = self.mem_manager.kv_buffer
             all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1)
@@ -37,83 +40,111 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_
     # then when the decode finishes, do syncronize to see whether this can be free
     # no buffer, parallel insert inputs
     def insert_disk(self, req_id, key, value):
+        if not self.do_store:
+            return
         if req_id in self.working_tasks:
-            self.wait_till_finish(req_id)
+            self.abort_req_store_task(req_id)
         self.working_tasks[req_id] = self.py_cache_service.create(tokens=key, kv_page_indexer=value, mode="w")
         logger.info(f"Created store task for req {req_id}.")
 
-    def wait_till_finish(self, req_id):
-        if req_id not in self.working_tasks:
+    def abort_req_store_task(self, req_id):
+        if not self.do_store:
+            return
+        if self.working_tasks[req_id].ready():
+            logger.info(f"Calling abort for req {req_id}, but is finished.")
             return
-        starting_time = time.time()
-        while not self.working_tasks[req_id].ready():
-            time.sleep(0.01)
-        logger.info(f"Waited {time.time() - starting_time}s for req {req_id}.")
-
-    # def insert(self, key, value=None):
-    #     if value is None:
-    #         value = key
-
-    #     assert len(key) == len(value)  # and len(key) >= 1
-    #     if len(key) == 0:
-    #         return 0
-
-    #     # current implement is serial, TODO: make it parallel
-    #     # if no hi_cache_buffer, work with normal radix cache
-    #     if self.hi_cache_kv_buffer is not None:
-    #         do_copy = False
-    #         # and if is moving, ignore this insert request
-    #         with self.moving_lock:
-    #             if (not self.start_store_task) and self.write_task is not None:
-    #                 if self.write_task.ready():
-    #                     logger.info(f"HiCache of [{self.rank_in_node}]: stored len = {self.hi_cache_buffer_len}")
-    #                     self.start_store_task = True  # ensure ready => start new only one kvcache stores
-    #                     do_copy = True
-    #             elif self.write_task is None and self.starting:
-    #                 self.starting = False
-    #                 self.start_store_task = True
-    #                 do_copy = True
-
-    #         if do_copy:
-    #             # copy the key and value to the hi_cache_buffer
-    #             self.hi_cache_key_buffer[:len(key)].copy_(key)
-    #             self.hi_cache_buffer_len = len(key)
-    #             for buffer_index, index in enumerate(value):
-    #                 kv_data = self.mem_manager.get_index_kv_buffer(index)
-    #                 self.mem_manager.load_index_kv_buffer(self.hi_cache_kv_buffer[buffer_index], kv_data)
-    #             # create a new thread to store the buffer
-    #             self._store_buffer()
-
-    #     return self._insert_helper(self.root_node, key, value)
-
-    # def _store_buffer(self):
-    #     logger.info(f"Storing buffer size = {self.hi_cache_buffer_len}")
-    #     assert self.hi_cache_buffer_len > 0
-    #     assert self.hi_cache_kv_buffer is not None
-    #     key = self.hi_cache_key_buffer[:self.hi_cache_buffer_len].tolist()
-    #     self.write_task = self.py_cache_service.create(
-    #         tokens=key, kv_page_indexer=self.hi_cache_kv_buffer[:self.hi_cache_buffer_len], mode="w")
-    #     with self.moving_lock:
-    #         self.start_store_task = False
+        logger.info(f"Aborting req {req_id} unfinished.")
+        self.py_cache_service.az5(self.working_tasks[req_id])
+
+    # TODO: finish this function to only update new ones
+    def _reinsert_helper(self, node: TreeNode, key, value, ans_value_list: list, update_refs=False):
+        if node.is_leaf():
+            self.evict_tree_set.discard(node)
+
+        if update_refs:
+            node.ref_counter += 1
+            # from 0 to 1 need update refs token num
+            if node.ref_counter == 1:
+                self.refed_tokens_num.arr[0] += len(node.token_mem_index_value)
+
+        try:
+            if len(key) == 0:
+                return node
+
+            first_key_id = key[0].item()
+            if first_key_id in node.children.keys():
+                child: TreeNode = node.children[first_key_id]
+                prefix_len = match(key, child.token_id_key)
+                if prefix_len == len(key):
+                    if child.is_leaf():
+                        self.evict_tree_set.discard(child)
+                    child.update_time()
+                    ans_value_list.append(child.token_mem_index_value)
+                    if child.is_leaf():
+                        self.evict_tree_set.add(child)
+                    return prefix_len
+
+                elif prefix_len < len(key) and prefix_len < len(child.token_id_key):
+                    if child.is_leaf():
+                        self.evict_tree_set.discard(child)
+
+                    key = key[prefix_len:]
+                    value = value[prefix_len:]
+                    split_parent_node = child.split_node(prefix_len)
+                    new_node = split_parent_node.add_and_return_new_child(key, value)
+                    # update total token num
+                    self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
+
+                    if split_parent_node.is_leaf():
+                        self.evict_tree_set.add(split_parent_node)
+                    if new_node.is_leaf():
+                        self.evict_tree_set.add(new_node)
+
+                    if child.is_leaf():
+                        self.evict_tree_set.add(child)
+                    return prefix_len
+                elif prefix_len < len(key) and prefix_len == len(child.token_id_key):
+                    return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:])
+                else:
+                    assert False, "can not run to here"
+
+            else:
+                new_node = node.add_and_return_new_child(key, value)
+                # update total token num
+                self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
+                ans_value_list.append(new_node.token_mem_index_value)
+                if update_refs:
+                    new_node.ref_counter += 1
+                    if new_node.ref_counter == 1:
+                        self.refed_tokens_num.arr[0] += len(new_node.token_mem_index_value)
+                if new_node.is_leaf():
+                    self.evict_tree_set.add(new_node)
+                return new_node
+        finally:
+            node.update_time()
+            if node.is_leaf():
+                self.evict_tree_set.add(node)
 
     def match_prefix(self, key, update_refs=False):
         st_time = time.time()
         assert len(key) != 0
         ans_value_list = []
-        tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
+        tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
         # add a parameter if get long enough (>50%)
         first_query_time = time.time()
         logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}")
         max_len = self._query_hi_cache(key)  # x64
         hi_cache_query_time = time.time()
         logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query took {hi_cache_query_time - first_query_time}")
-        logger.info(f"Matched {len(ans_value_list)} from gpu and 24,380 from disk.")
+        logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and 24,380 from disk.")
         pull_hi_cache = False
-        if max_len > len(ans_value_list):
+        if max_len > sum(len(s) for s in ans_value_list):
             pull_hi_cache = True
             try:
                 self.free_radix_cache_to_get_enough_token(max_len)
             except:
+                if update_refs:
+                    tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
                 pull_hi_cache = False
         if pull_hi_cache:
             buffers = self.mem_manager.alloc(max_len)
@@ -133,7 +164,10 @@ def match_prefix(self, key, update_refs=False):
             logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}")
             ans_value_list = []
             tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
-            logger.info(f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}")
+            logger.info(
+                f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}"
+                + f" matched {sum(len(s) for s in ans_value_list)} tokens"
+            )
         if tree_node != self.root_node:
             if len(ans_value_list) != 0:
                 value = torch.concat(ans_value_list)
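
The hunk above gates disk stores behind a new do_store flag (with store_once enabled, only rank 0 in the node stores) and replaces the blocking wait_till_finish with abort_req_store_task, which aborts an unfinished write via py_cache_service.az5 instead of polling for completion. It also switches the matched-token accounting from len(ans_value_list), which counts tree nodes, to sum(len(s) for s in ans_value_list), which counts tokens, since each entry in ans_value_list is a per-node chunk of KV indices. Below is a minimal, self-contained sketch of that store-task lifecycle; MockTask and MockCacheService are hypothetical stand-ins for the real py_cache_service, not the lightllm API.

    # Hypothetical stand-ins for py_cache_service and its async write task.
    class MockTask:
        def __init__(self):
            self.done = False

        def ready(self):
            return self.done


    class MockCacheService:
        def create(self, tokens, kv_page_indexer, mode):
            # Pretend to start an async disk write and hand back its handle.
            return MockTask()

        def az5(self, task):
            # Abort an in-flight write task (mirrors the az5() call above).
            task.done = True


    class StoreManager:
        def __init__(self, rank_in_node, store_once=True):
            # Only one rank per node stores when the KV data is replicated across ranks.
            self.do_store = store_once and rank_in_node == 0
            self.py_cache_service = MockCacheService()
            self.working_tasks = {}

        def insert_disk(self, req_id, key, value):
            if not self.do_store:
                return
            # A new store request for the same req aborts the previous unfinished one.
            if req_id in self.working_tasks:
                self.abort_req_store_task(req_id)
            self.working_tasks[req_id] = self.py_cache_service.create(key, value, "w")

        def abort_req_store_task(self, req_id):
            if not self.do_store:
                return
            task = self.working_tasks[req_id]
            if not task.ready():
                self.py_cache_service.az5(task)


    mgr = StoreManager(rank_in_node=0)
    mgr.insert_disk("req-1", key=[1, 2, 3], value=[10, 11, 12])
    mgr.abort_req_store_task("req-1")  # the pending write is aborted, not waited on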

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis
             req.shared_kv_node = None
 
         if self.radix_cache.is_hi_radix_cache:
-            self.radix_cache.wait_till_finish(req.req_id)
+            self.radix_cache.abort_req_store_task(req.req_id)
 
     def _save_promptcache_kvbuffer(self):
         """

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 17 additions & 7 deletions
@@ -165,13 +165,6 @@ def decode(self):
         """This method can be overridden in subclasses."""
         raise NotImplementedError()
 
-    def store_hicache_after_prefill(self, run_reqs):
-        if self.use_hi_dynamic_prompt_cache and self.radix_cache is not None:
-            for req in run_reqs:
-                key = torch.tensor(req.get_input_token_ids()[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
-                value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu()
-                self.radix_cache.insert_disk(req.req_id, key, value)
-
     def pause_reqs(self, req_ids):
         if self.dp_size_in_node != 1:
             req_ids = [req_id for req_id in req_ids if req_id in g_infer_context.requests_mapping]
@@ -350,6 +343,23 @@ def _overlap_req_init_and_filter(
 
         return
 
+    def _overlap_store_prefill_reqs(self, run_reqs: List[InferReq]):
+        if run_reqs:
+            with torch.cuda.stream(g_infer_context.get_overlap_stream()):
+                if self.use_hi_dynamic_prompt_cache and self.radix_cache is not None:
+                    for req in run_reqs:
+                        if req.cur_output_len > 1:
+                            continue
+                        key = torch.tensor(
+                            req.get_input_token_ids()[0 : req.cur_kv_len], dtype=torch.int64, device="cpu"
+                        )
+                        value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu()
+                        self.radix_cache.insert_disk(req.req_id, key, value)
+
+            torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream())
+
+        return
+
     # some reusable, general-purpose helper functions
     def _post_init_reqs(self, uninit_reqs: List[InferReq]):
         """

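The new _overlap_store_prefill_reqs issues the per-request host copies and insert_disk calls on the shared overlap CUDA stream and then makes the current stream wait on it, so the disk-store preparation is enqueued alongside the main decode work; the cur_output_len > 1 check skips requests that are already past their first generated token, so only freshly prefilled requests are stored. Below is a minimal sketch of the generic side-stream pattern in plain PyTorch; it is an illustration only and does not use lightllm's g_infer_context.

    import torch

    def overlapped_host_copy(dev_tensor: torch.Tensor, side_stream: torch.cuda.Stream) -> torch.Tensor:
        # Enqueue the device-to-host copy on the secondary stream ...
        with torch.cuda.stream(side_stream):
            host_copy = dev_tensor.detach().cpu()
        # ... then order the current (default) stream after it before continuing.
        torch.cuda.current_stream().wait_stream(side_stream)
        return host_copy

    if torch.cuda.is_available():
        side = torch.cuda.Stream()
        x = torch.arange(1024, device="cuda")
        print(overlapped_host_copy(x, side).shape)  # torch.Size([1024])
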
lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 1 addition & 1 deletion
@@ -43,6 +43,7 @@ def decode(self):
            self._overlap_req_init_and_filter(
                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
            )
+           self._overlap_store_prefill_reqs(run_reqs=run_reqs)
            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
            next_token_ids = next_token_ids.detach().cpu().numpy()
            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
@@ -59,7 +60,6 @@ def decode(self):
                prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
            )
            logits = self.model.forward(**kwargs)
-           self.store_hicache_after_prefill(run_reqs)
            self._overlap_req_init_and_filter(
                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
            )

lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py

Lines changed: 0 additions & 2 deletions
@@ -35,8 +35,6 @@ def decode(self):
            )
            logits = self.model.forward(**kwargs)
 
-           self.store_hicache_after_prefill(run_reqs)
-
            self._overlap_req_init_and_filter(
                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
            )
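
With the two backend changes above, the chunked_prefill backend now triggers the disk store through _overlap_store_prefill_reqs(run_reqs=run_reqs) inside its decode loop rather than calling store_hicache_after_prefill right after the prefill forward, while continues_batch only has the old call removed, with no replacement call added in this commit. A small illustrative sketch of the cur_output_len guard, which keeps a request from being re-stored on every decode step, is below; FakeReq and reqs_to_store are hypothetical names, not lightllm code.

    # Only requests that just finished prefill (at most one output token) are stored,
    # so a request's prompt KV is pushed to disk at most once.
    from dataclasses import dataclass
    from typing import List

    @dataclass
    class FakeReq:  # stand-in for InferReq
        req_id: str
        cur_output_len: int

    def reqs_to_store(run_reqs: List[FakeReq]) -> List[str]:
        return [r.req_id for r in run_reqs if r.cur_output_len <= 1]

    print(reqs_to_store([FakeReq("a", 1), FakeReq("b", 5), FakeReq("c", 0)]))  # ['a', 'c']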
