@@ -100,32 +100,8 @@ def generate_medusa_buffers(medusa_choices, device="cuda"):
100100 for i in range (len (depth_counts )):
101101 for j in range (depth_counts [i ]):
102102 cur_medusa_choice = sorted_medusa_choices [start + j ]
103- medusa_tree_indices [start + j + 1 ] = cur_medusa_choice [- 1 ] + TOPK * i + 1 ##根据每组最后一个节点和所在深度计算所在位置
103+ medusa_tree_indices [start + j + 1 ] = cur_medusa_choice [- 1 ] + TOPK * i + 1
104104 start += depth_counts [i ]
105- """
106- 逻辑上结构:
107- A (原始头预测token, 没在sorted_medusa_choices中)
108-
109- B C ... K (第一个头预测token, 预测topk个)
110-
111- banana cute ... key(第二个头预测token, 预测topk个)
112-
113- 铺平之后: A B C ... K banana cute ... key (一共1+topk*深度个=1+4*10=41个)
114-
115- A:0
116- ----
117- B:1
118- C:2
119- ...
120- k:11
121- ----
122- banana:12
123- cute:13
124- key:22
125-
126- 不是所有路径都选,节点有可能被多条路径选多次,事先设置选64个路径
127- medusa_tree_indices: 所有路径经过的节点,根据从短到长,从小到大记录下平铺后最后一个节点序号
128- """
129105
130106 # Generate position IDs for the Medusa structure
131107 medusa_position_ids = torch .zeros (medusa_len , dtype = torch .long )
@@ -138,7 +114,7 @@ def generate_medusa_buffers(medusa_choices, device="cuda"):
138114 retrieve_indices_nest = []
139115 retrieve_paths = []
140116 for i in range (len (sorted_medusa_choices )):
141- cur_medusa_choice = sorted_medusa_choices [- i - 1 ] ##倒着循环
117+ cur_medusa_choice = sorted_medusa_choices [- i - 1 ]
142118 retrieve_indice = []
143119 if cur_medusa_choice in retrieve_paths :
144120 continue
@@ -356,12 +332,6 @@ def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices, t
356332 # Unsqueeze the tree candidates for dimension consistency.
357333 # tree_candidates = tree_candidates.unsqueeze(0)
358334 return cart_candidates , tree_candidates
359- """
360- cart_candidates.shape
361- torch.Size([2, 42, 5])
362- tree_candidates.shape
363- torch.Size([2, 64])
364- """
365335
366336def update_position_id (medusa_position_ids , attention_mask , input_ids ):
367337 bs = input_ids .shape [0 ]
@@ -376,9 +346,7 @@ def update_position_id(medusa_position_ids, attention_mask, input_ids):
def update_attention_mask(attention_mask, tree_candidates):
    """Append attended (all-ones) columns for the tree-candidate tokens.

    Args:
        attention_mask: Existing mask tensor; it is concatenated along
            dim 1, so it must be at least 2-D with a matching batch size.
        tree_candidates: Tensor whose first two dimensions give the batch
            size and the number of candidate tokens to append.

    Returns:
        A new tensor equal to ``attention_mask`` with one extra column of
        ones per candidate token, keeping the dtype and device of
        ``attention_mask``.
    """
    batch_size = tree_candidates.size(0)
    num_candidates = tree_candidates.size(1)
    # new_ones inherits dtype and device from attention_mask.
    candidate_mask = attention_mask.new_ones((batch_size, num_candidates))
    return torch.cat((attention_mask, candidate_mask), dim=1)
384352
@@ -549,7 +517,6 @@ def evaluate_posterior(
549517
550518 if sampling == 'typical' :
551519 if fast :
552- ## logits 最后一个是新预测的,candidates第0个是原始头的输出,不用比较
553520 posterior_prob = torch .softmax (logits [:,:,:- 1 ] / temperature , dim = - 1 )
554521 candidates_prob = torch .gather (
555522 posterior_prob , dim = - 1 , index = candidates [:,:,1 :].unsqueeze (- 1 )
@@ -634,22 +601,16 @@ def gather_from_past_key_values(past_key_values_data, select_indices):
634601 layers , batch_size , head_num , _ , hidden_size = past_key_values_data .shape
635602 seqlen = select_indices .shape [1 ]
636603
637- # 初始化结果张量,用于存放选择的数据或全零填充
638604 result_data = torch .zeros (layers , batch_size , head_num , seqlen , hidden_size , device = past_key_values_data .device , dtype = past_key_values_data .dtype )
639605
640- # 扩展 select_indices 以匹配 past_key_values_data 的操作维度
641606 expanded_indices = select_indices .unsqueeze (0 ).unsqueeze (2 ).expand (layers , batch_size , head_num , seqlen )
642607
643- # 创建一个掩码,用于识别 select_indices 中的有效索引(非 -1 值)
644608 valid_indices_mask = expanded_indices != - 1
645609
646- # 修正 -1 索引值以避免 gather 时的错误,将 -1 替换为一个有效的索引(如 0),后续再通过掩码处理
647610 corrected_indices = torch .where (valid_indices_mask , expanded_indices , torch .zeros_like (expanded_indices ))
648611
649- # 使用 gather 选择数据
650612 gathered_data = torch .gather (past_key_values_data , 3 , corrected_indices .unsqueeze (- 1 ).expand (- 1 , - 1 , - 1 , - 1 , hidden_size ))
651613
652- # 利用掩码将结果中对应 -1 索引的位置替换为全零
653614 result_data = torch .where (valid_indices_mask .unsqueeze (- 1 ), gathered_data , result_data )
654615 return result_data
655616
@@ -659,9 +620,7 @@ def update_ids_new(input_ids, new_ids):
659620 return input_ids
660621
def update_mask(attention_mask, accept_length):
    """Extend the attention mask with per-row accepted-token columns.

    For every batch row, ``accept_length.max()`` new columns are appended;
    the first ``accept_length[row]`` of them are 1 and the rest are 0.

    Args:
        attention_mask: Existing mask tensor, concatenated along the last
            dimension.
        accept_length: 1-D integer tensor with the number of accepted
            tokens for each batch row.

    Returns:
        ``attention_mask`` with the new columns appended, keeping its own
        dtype and device.
    """
    max_accept = accept_length.max().item()
    # Build the column indices on accept_length's device rather than a
    # hard-coded 'cuda:0', so this also works on CPU and on other GPUs.
    positions = torch.arange(max_accept, device=accept_length.device).expand(
        accept_length.shape[0], -1
    )
    # Columns below each row's accept length are valid (1), the rest 0.
    accepted = positions < accept_length.unsqueeze(1)
    # Match the existing mask's dtype/device so the concatenation neither
    # fails nor silently promotes the mask to another dtype.
    new_attention_mask = accepted.to(
        device=attention_mask.device, dtype=attention_mask.dtype
    )
    return torch.cat((attention_mask, new_attention_mask), dim=-1)
@@ -769,8 +728,8 @@ def update_inference_inputs(
769728 valid_length = accept_length
770729 else :
771730 # Extract logits and medusa logits for the last accepted tokens
772- logits = logits [batch_indices , best_candidate , accept_length - 1 ] #最后一个logits
773- medusa_logits = medusa_logits [:, batch_indices , best_candidate , accept_length - 1 ] #最后一个logits
731+ logits = logits [batch_indices , best_candidate , accept_length - 1 ]
732+ medusa_logits = medusa_logits [:, batch_indices , best_candidate , accept_length - 1 ]
774733 valid_length = None
775734 # Update the new token counter
776735 new_token += max_accept_length
0 commit comments