Skip to content

Commit 684eb87

Browse files
Author: Ubuntu
Commit message: clean
1 parent b3a7d5b commit 684eb87

File tree

2 files changed

+4 / -54 lines changed

2 files changed

+4 / -54 lines changed

medusa/model/utils.py

Lines changed: 4 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -100,32 +100,8 @@ def generate_medusa_buffers(medusa_choices, device="cuda"):
100100
for i in range(len(depth_counts)):
101101
for j in range(depth_counts[i]):
102102
cur_medusa_choice = sorted_medusa_choices[start + j]
103-
medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1 ##根据每组最后一个节点和所在深度计算所在位置
103+
medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
104104
start += depth_counts[i]
105-
"""
106-
逻辑上结构:
107-
A (原始头预测token, 没在sorted_medusa_choices中)
108-
109-
B C ... K (第一个头预测token, 预测topk个)
110-
111-
banana cute ... key(第二个头预测token, 预测topk个)
112-
113-
铺平之后: A B C ... K banana cute ... key (一共1+topk*深度个=1+4*10=41个)
114-
115-
A:0
116-
----
117-
B:1
118-
C:2
119-
...
120-
k:11
121-
----
122-
banana:12
123-
cute:13
124-
key:22
125-
126-
不是所有路径都选,节点有可能被多条路径选多次,事先设置选64个路径
127-
medusa_tree_indices: 所有路径经过的节点,根据从短到长,从小到大记录下平铺后最后一个节点序号
128-
"""
129105

130106
# Generate position IDs for the Medusa structure
131107
medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
@@ -138,7 +114,7 @@ def generate_medusa_buffers(medusa_choices, device="cuda"):
138114
retrieve_indices_nest = []
139115
retrieve_paths = []
140116
for i in range(len(sorted_medusa_choices)):
141-
cur_medusa_choice = sorted_medusa_choices[-i-1] ##倒着循环
117+
cur_medusa_choice = sorted_medusa_choices[-i-1]
142118
retrieve_indice = []
143119
if cur_medusa_choice in retrieve_paths:
144120
continue
@@ -356,12 +332,6 @@ def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices, t
356332
# Unsqueeze the tree candidates for dimension consistency.
357333
# tree_candidates = tree_candidates.unsqueeze(0)
358334
return cart_candidates, tree_candidates
359-
"""
360-
cart_candidates.shape
361-
torch.Size([2, 42, 5])
362-
tree_candidates.shape
363-
torch.Size([2, 64])
364-
"""
365335

366336
def update_position_id(medusa_position_ids, attention_mask, input_ids):
367337
bs = input_ids.shape[0]
@@ -376,9 +346,7 @@ def update_position_id(medusa_position_ids, attention_mask, input_ids):
376346
def update_attention_mask(attention_mask, tree_candidates):
377347
bs = tree_candidates.shape[0]
378348
n = tree_candidates.shape[1]
379-
# 创建一个新的张量,用于在尾部添加n个token
380349
new_tokens = torch.ones((bs, n), dtype=attention_mask.dtype, device=attention_mask.device)
381-
# 使用torch.cat来扩增attention_mask
382350
extended_attention_mask = torch.cat((attention_mask, new_tokens), dim=1)
383351
return extended_attention_mask
384352

@@ -549,7 +517,6 @@ def evaluate_posterior(
549517

550518
if sampling == 'typical':
551519
if fast:
552-
## logits 最后一个是新预测的,candidates第0个是原始头的输出,不用比较
553520
posterior_prob = torch.softmax(logits[:,:,:-1] / temperature, dim=-1)
554521
candidates_prob = torch.gather(
555522
posterior_prob, dim=-1, index=candidates[:,:,1:].unsqueeze(-1)
@@ -634,22 +601,16 @@ def gather_from_past_key_values(past_key_values_data, select_indices):
634601
layers, batch_size, head_num, _, hidden_size = past_key_values_data.shape
635602
seqlen = select_indices.shape[1]
636603

637-
# 初始化结果张量,用于存放选择的数据或全零填充
638604
result_data = torch.zeros(layers, batch_size, head_num, seqlen, hidden_size, device=past_key_values_data.device, dtype=past_key_values_data.dtype)
639605

640-
# 扩展 select_indices 以匹配 past_key_values_data 的操作维度
641606
expanded_indices = select_indices.unsqueeze(0).unsqueeze(2).expand(layers, batch_size, head_num, seqlen)
642607

643-
# 创建一个掩码,用于识别 select_indices 中的有效索引(非 -1 值)
644608
valid_indices_mask = expanded_indices != -1
645609

646-
# 修正 -1 索引值以避免 gather 时的错误,将 -1 替换为一个有效的索引(如 0),后续再通过掩码处理
647610
corrected_indices = torch.where(valid_indices_mask, expanded_indices, torch.zeros_like(expanded_indices))
648611

649-
# 使用 gather 选择数据
650612
gathered_data = torch.gather(past_key_values_data, 3, corrected_indices.unsqueeze(-1).expand(-1, -1, -1, -1, hidden_size))
651613

652-
# 利用掩码将结果中对应 -1 索引的位置替换为全零
653614
result_data = torch.where(valid_indices_mask.unsqueeze(-1), gathered_data, result_data)
654615
return result_data
655616

@@ -659,9 +620,7 @@ def update_ids_new(input_ids, new_ids):
659620
return input_ids
660621

661622
def update_mask(attention_mask, accept_length):
662-
# 创建一个每行都是0到max_seqlen-1的范围张量
663623
range_tensor = torch.arange(accept_length.max().item(), device='cuda:0').expand(accept_length.shape[0], -1)
664-
# 根据 accept_length 生成 mask,其中有效长度标记为1,其他为0
665624
new_attention_mask = (range_tensor < accept_length.unsqueeze(1)).to(int)
666625
attention_mask = torch.cat((attention_mask, new_attention_mask), dim=-1)
667626
return attention_mask
@@ -769,8 +728,8 @@ def update_inference_inputs(
769728
valid_length = accept_length
770729
else:
771730
# Extract logits and medusa logits for the last accepted tokens
772-
logits = logits[batch_indices, best_candidate, accept_length-1] #最后一个logits
773-
medusa_logits = medusa_logits[:, batch_indices, best_candidate, accept_length-1] #最后一个logits
731+
logits = logits[batch_indices, best_candidate, accept_length-1]
732+
medusa_logits = medusa_logits[:, batch_indices, best_candidate, accept_length-1]
774733
valid_length = None
775734
# Update the new token counter
776735
new_token += max_accept_length

run.sh

Lines changed: 0 additions & 9 deletions
This file was deleted.

0 commit comments

Comments (0)