@@ -17,6 +17,9 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_
17
17
logger .info ("Initializing HiRadixCache" )
18
18
self .rank_in_node = rank_in_node
19
19
try :
20
+ # TODO: determine by model type and dp/tp configuration
21
+ store_once = True # Deepseek -> True, Llama -> False
22
+ self .do_store = store_once and self .rank_in_node == 0
20
23
self .is_hi_radix_cache = True
21
24
all_buffers = self .mem_manager .kv_buffer
22
25
all_buffers = all_buffers .view (all_buffers .shape [0 ], all_buffers .shape [1 ], - 1 )
@@ -37,83 +40,111 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_
37
40
# then when the decode finishes, do synchronize to see whether this can be freed
38
41
# no buffer, parallel insert inputs
39
42
def insert_disk(self, req_id, key, value):
    """Register an asynchronous disk-store task for a request's KV pages.

    Args:
        req_id: identifier of the request whose KV cache is being stored.
        key: token sequence used as the cache key.
        value: KV page indexer handed to the cache service.

    No-op on ranks that are not designated to store (``self.do_store``).
    Any still-registered task for the same request is aborted first, so
    at most one store task per request is tracked in ``working_tasks``.
    """
    # Only the designated storing rank writes to disk.
    if not self.do_store:
        return
    # Replace (abort) a stale task for this request before creating a new one.
    if req_id in self.working_tasks:
        self.abort_req_store_task(req_id)
    store_task = self.py_cache_service.create(tokens=key, kv_page_indexer=value, mode="w")
    self.working_tasks[req_id] = store_task
    logger.info(f"Created store task for req {req_id}.")
44
49
45
def abort_req_store_task(self, req_id):
    """Abort the in-flight disk-store task for ``req_id``, if any.

    No-op when this rank does not store to disk, when no task is
    registered for the request, or when the task has already finished.

    Args:
        req_id: identifier of the request whose store task should stop.
    """
    if not self.do_store:
        return
    # Guard against callers aborting a request that never registered a
    # store task (or whose entry was already removed): indexing
    # self.working_tasks[req_id] directly would raise KeyError here.
    task = self.working_tasks.get(req_id)
    if task is None:
        return
    if task.ready():
        # Nothing to abort; the write already completed.
        logger.info(f"Calling abort for req {req_id}, but is finished.")
        return
    logger.info(f"Aborting req {req_id} unfinished.")
    # az5 is the cache service's emergency-cancel for a pending task
    # (presumably named after the AZ-5 scram button) — confirm semantics
    # against py_cache_service.
    self.py_cache_service.az5(task)
58
+
59
# TODO: finish this function to only update new ones
def _reinsert_helper(self, node: TreeNode, key, value, ans_value_list: list, update_refs=False):
    """Re-insert (key, value) under `node` after a disk pull, updating the radix tree.

    Walks/extends the tree along `key`, splitting or creating children as
    needed, and appends matched/created token-mem-index values to
    `ans_value_list` (mutated in place). When `update_refs` is True the
    entry node's (and any newly created leaf's) ref counter is bumped.

    NOTE(review): return type is inconsistent by design-in-progress — the
    empty-key path returns `node`, the matched/split paths return an int
    prefix length, and the new-child path returns the new TreeNode. The
    TODO above suggests this is unfinished; callers must tolerate both.

    # assumes `key` holds tensor-like elements (key[0].item()) — TODO confirm
    """
    # Nodes in the evict set must be removed before mutation and re-added
    # after, since their ordering key (update_time/ref state) changes.
    if node.is_leaf():
        self.evict_tree_set.discard(node)

    if update_refs:
        node.ref_counter += 1
        # from 0 to 1 need update refs token num
        if node.ref_counter == 1:
            self.refed_tokens_num.arr[0] += len(node.token_mem_index_value)

    try:
        if len(key) == 0:
            return node

        first_key_id = key[0].item()
        if first_key_id in node.children.keys():
            child: TreeNode = node.children[first_key_id]
            prefix_len = match(key, child.token_id_key)
            if prefix_len == len(key):
                # Whole remaining key lies inside this child: record its
                # value and refresh its recency.
                if child.is_leaf():
                    self.evict_tree_set.discard(child)
                child.update_time()
                ans_value_list.append(child.token_mem_index_value)
                if child.is_leaf():
                    self.evict_tree_set.add(child)
                return prefix_len

            elif prefix_len < len(key) and prefix_len < len(child.token_id_key):
                # Key diverges inside the child's edge: split the child at
                # prefix_len and hang the unmatched suffix as a new node.
                if child.is_leaf():
                    self.evict_tree_set.discard(child)

                key = key[prefix_len:]
                value = value[prefix_len:]
                split_parent_node = child.split_node(prefix_len)
                new_node = split_parent_node.add_and_return_new_child(key, value)
                # update total token num
                self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)

                if split_parent_node.is_leaf():
                    self.evict_tree_set.add(split_parent_node)
                if new_node.is_leaf():
                    self.evict_tree_set.add(new_node)

                if child.is_leaf():
                    self.evict_tree_set.add(child)
                return prefix_len
            elif prefix_len < len(key) and prefix_len == len(child.token_id_key):
                # Child's edge fully consumed: recurse (via _insert_helper)
                # into the child with the remaining suffix.
                return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:])
            else:
                assert False, "can not run to here"

        else:
            # No child starts with key[0]: attach the whole remainder as a
            # fresh child of `node`.
            new_node = node.add_and_return_new_child(key, value)
            # update total token num
            self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
            ans_value_list.append(new_node.token_mem_index_value)
            if update_refs:
                new_node.ref_counter += 1
                if new_node.ref_counter == 1:
                    self.refed_tokens_num.arr[0] += len(new_node.token_mem_index_value)
            if new_node.is_leaf():
                self.evict_tree_set.add(new_node)
            return new_node
    finally:
        # Always refresh the entry node's recency and restore it to the
        # evict set, regardless of which branch returned.
        node.update_time()
        if node.is_leaf():
            self.evict_tree_set.add(node)
98
127
99
128
def match_prefix (self , key , update_refs = False ):
100
129
st_time = time .time ()
101
130
assert len (key ) != 0
102
131
ans_value_list = []
103
- tree_node = self ._match_prefix_helper (self .root_node , key , ans_value_list , update_refs = update_refs )
132
+ tree_node = self ._match_prefix_helper (self .root_node , key , ans_value_list , update_refs = False )
104
133
# add a parameter if get long enough (>50%)
105
134
first_query_time = time .time ()
106
135
logger .info (f"HiCache of [{ self .rank_in_node } ]: No.1 First GPU query took { first_query_time - st_time } " )
107
136
max_len = self ._query_hi_cache (key ) # x64
108
137
hi_cache_query_time = time .time ()
109
138
logger .info (f"HiCache of [{ self .rank_in_node } ]: No.2 Disk query took { hi_cache_query_time - first_query_time } " )
110
- logger .info (f"Matched { len (ans_value_list )} from gpu and { max_len } from disk." )
139
+ logger .info (f"Matched { sum ( len (s ) for s in ans_value_list )} from gpu and { max_len } from disk." )
111
140
pull_hi_cache = False
112
- if max_len > len (ans_value_list ):
141
+ if max_len > sum ( len (s ) for s in ans_value_list ):
113
142
pull_hi_cache = True
114
143
try :
115
144
self .free_radix_cache_to_get_enough_token (max_len )
116
145
except :
146
+ if update_refs :
147
+ tree_node = self ._match_prefix_helper (self .root_node , key , ans_value_list , update_refs = update_refs )
117
148
pull_hi_cache = False
118
149
if pull_hi_cache :
119
150
buffers = self .mem_manager .alloc (max_len )
@@ -133,7 +164,10 @@ def match_prefix(self, key, update_refs=False):
133
164
logger .info (f"HiCache of [{ self .rank_in_node } ]: No.4 Reinsert took { insert_time - hicache_pull_time } " )
134
165
ans_value_list = []
135
166
tree_node = self ._match_prefix_helper (self .root_node , key , ans_value_list , update_refs = update_refs )
136
- logger .info (f"HiCache of [{ self .rank_in_node } ]: No.5 Re match prefix took { time .time () - insert_time } " )
167
+ logger .info (
168
+ f"HiCache of [{ self .rank_in_node } ]: No.5 Re match prefix took { time .time () - insert_time } "
169
+ + f" matched { sum (len (s ) for s in ans_value_list )} tokens"
170
+ )
137
171
if tree_node != self .root_node :
138
172
if len (ans_value_list ) != 0 :
139
173
value = torch .concat (ans_value_list )
0 commit comments