Feature(MInference): support LLaMA-3-70B-1M and multi-gpu PP (#59)
Co-authored-by: Yucheng Li <liyucheng09@gmail.com>
Co-authored-by: Chengruidong Zhang <chengzhang@microsoft.com>
Co-authored-by: Yuqing Yang <justin.yqyang@gmail.com>
4 people authored Aug 1, 2024
1 parent b5b8745 commit 7a11a33
Showing 7 changed files with 46 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -62,7 +62,7 @@ get_support_models()

Currently, we support the following LLMs:
- LLaMA-3.1: [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
- - LLaMA-3: [gradientai/Llama-3-8B-Instruct-262k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-262k), [gradientai/Llama-3-8B-Instruct-Gradient-1048k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k), [gradientai/Llama-3-8B-Instruct-Gradient-4194k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-4194k)
+ - LLaMA-3: [gradientai/Llama-3-8B-Instruct-262k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-262k), [gradientai/Llama-3-8B-Instruct-Gradient-1048k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-1048k), [gradientai/Llama-3-8B-Instruct-Gradient-4194k](https://huggingface.co/gradientai/Llama-3-8B-Instruct-Gradient-4194k), [gradientai/Llama-3-70B-Instruct-Gradient-262k](https://huggingface.co/gradientai/Llama-3-70B-Instruct-Gradient-262k), [gradientai/Llama-3-70B-Instruct-Gradient-1048k](https://huggingface.co/gradientai/Llama-3-70B-Instruct-Gradient-1048k)
- GLM-4: [THUDM/glm-4-9b-chat-1m](https://huggingface.co/THUDM/glm-4-9b-chat-1m)
- Yi: [01-ai/Yi-9B-200K](https://huggingface.co/01-ai/Yi-9B-200K)
- Phi-3: [microsoft/Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct)
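A minimal usage sketch for the newly listed 70B long-context checkpoints, assuming the MInference patching API shown earlier in this README; the model name, dtype, and generation settings here are illustrative, not part of the diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from minference import MInference

model_name = "gradientai/Llama-3-70B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="auto" shards the 70B checkpoint across all visible GPUs
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
# Patch attention with MInference's dynamic sparse attention
minference_patch = MInference("minference", model_name)
model = minference_patch(model)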
2 changes: 1 addition & 1 deletion experiments/infinite_bench/run_infinitebench.py
@@ -166,7 +166,7 @@ def load_model(
model_name,
config=config,
torch_dtype="auto",
device_map="cuda",
device_map="auto",
resume_download=None,
trust_remote_code=trust_remote_code,
)
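For reference, a small sketch of how the multi-GPU placement produced by device_map="auto" can be inspected; the hf_device_map attribute is populated by transformers/accelerate when a device map is used, and the helper name below is illustrative.

import collections
from transformers import AutoModelForCausalLM

def show_device_map(model_name: str):
    # Load with an automatic device map, then print the module -> device placement
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto"
    )
    print(model.hf_device_map)                                # e.g. layer -> GPU index
    print(collections.Counter(model.hf_device_map.values()))  # modules per device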

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions minference/configs/model2path.py
@@ -29,6 +29,12 @@
"meta-llama/Meta-Llama-3.1-8B-Instruct": os.path.join(
BASE_DIR, "Llama_3.1_8B_Instruct_128k_kv_out_v32_fit_o_best_pattern.json"
),
"gradientai/Llama-3-70B-Instruct-Gradient-262k": os.path.join(
BASE_DIR, "Llama_3_70B_Instruct_262k_kv_out_v32_fit_o_best_pattern.json"
),
"gradientai/Llama-3-70B-Instruct-Gradient-1048k": os.path.join(
BASE_DIR, "Llama_3_70B_Instruct_262k_kv_out_v32_fit_o_best_pattern.json"
),
}


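The entries added above are consumed as a plain name-to-path lookup from Hugging Face model ID to the shipped best-pattern JSON. A hedged sketch of that lookup; the helper name is illustrative, not part of the module.

import os

def resolve_pattern_config(model_name: str, model2path: dict, base_dir: str = ".") -> str | None:
    # Return the best-pattern JSON path for a supported model, or None otherwise.
    path = model2path.get(model_name)
    return path if path is not None and os.path.exists(path) else None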
35 changes: 22 additions & 13 deletions minference/modules/inf_llm.py
@@ -149,6 +149,7 @@ def append(self, tensor: torch.Tensor):
self.append_cache()

self.data[self.length : self.length + append_l, ...].copy_(tensor)
+ self.data = self.data.to(tensor.device)

self.length += append_l
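A short sketch of the pattern behind the added line: once layers are pipelined across GPUs, the growing cache buffer and the tensor being appended may live on different devices, so the buffer is kept co-located with the incoming tensor. The class and argument names below are illustrative, not the module's own.

import torch

class GrowingCache:
    def __init__(self, capacity: int, dim: int, dtype=torch.float16, device="cuda:0"):
        self.data = torch.empty(capacity, dim, dtype=dtype, device=device)
        self.length = 0

    def append(self, tensor: torch.Tensor):
        append_l = tensor.size(0)
        # keep the buffer on the same GPU as the incoming tensor
        self.data = self.data.to(tensor.device)
        self.data[self.length : self.length + append_l].copy_(tensor)
        self.length += append_l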

@@ -567,9 +568,14 @@ def _append(self, local_q, local_k, local_v, global_q):

# calc local result first to overlap host-device communication
attn = self.Attn(local_h_q.shape, local_h_q.dtype, local_h_q.device)
- attn.append(
- local_h_q, local_h_k, local_h_v, get_score=True, sliding_window=self.n_local
- )
+ with torch.cuda.device(local_h_k.device):
+ attn.append(
+ local_h_q,
+ local_h_k,
+ local_h_v,
+ get_score=True,
+ sliding_window=self.n_local,
+ )

# calc topk global repr k and load cache
with torch.cuda.stream(GLOBAL_STREAM):
@@ -612,15 +618,16 @@ def _append(self, local_q, local_k, local_v, global_q):
torch.cuda.current_stream().wait_stream(GLOBAL_STREAM)

# calc global result
- attn.append(
- global_h_q,
- global_h_k,
- global_h_v,
- end=True,
- get_score=self.calc_block_score,
- sliding_window=global_sliding_window,
- complement_sliding_window=True,
- )
+ with torch.cuda.device(global_h_q.device):
+ attn.append(
+ global_h_q,
+ global_h_k,
+ global_h_v,
+ end=True,
+ get_score=self.calc_block_score,
+ sliding_window=global_sliding_window,
+ complement_sliding_window=True,
+ )

o, score_list = attn.get_result()
loc_score = score_list[0]
@@ -1238,7 +1245,9 @@ def _decode(
if word.item() in end_token_ids or i == max_length:
break

- input_ids = torch.cat((input_ids, word.view(1, 1)), dim=-1)
+ input_ids = torch.cat(
+ (input_ids, word.view(1, 1).to(input_ids.device)), dim=-1
+ )
attention_mask = torch.cat(
(
attention_mask,
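A condensed sketch of the two device-safety idioms the changes above rely on under multi-GPU pipeline parallelism: launch CUDA kernels with the GPU that owns the inputs as the current device, and move tensors produced on another stage (such as the sampled token) onto the destination tensor's device before concatenation. Function names here are illustrative.

import torch

def run_on_owning_device(fn, *tensors, **kwargs):
    # CUDA kernels launch on the "current" device; make it the GPU that
    # actually holds this layer's tensors.
    with torch.cuda.device(tensors[0].device):
        return fn(*tensors, **kwargs)

def append_token(input_ids: torch.Tensor, word: torch.Tensor) -> torch.Tensor:
    # The sampled token may come from the last pipeline stage on another GPU.
    return torch.cat((input_ids, word.view(1, 1).to(input_ids.device)), dim=-1)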
17 changes: 9 additions & 8 deletions minference/modules/minference_forward.py
@@ -187,10 +187,10 @@ def block_sparse(topk_ratio, slash_size=None):
attention_mask = torch.full((q_len, q_len), torch.finfo(q.dtype).min, device="cuda")
mask_cond = torch.arange(attention_mask.size(-1), device="cuda")
attention_mask.masked_fill_(mask_cond < (mask_cond + 1).view(attention_mask.size(-1), 1), 0)
- attention_mask = attention_mask[None, None, :]
+ attention_mask = attention_mask[None, None, :].to(q.device)
SEARCH_MASK = attention_mask
else:
- attention_mask = SEARCH_MASK
+ attention_mask = SEARCH_MASK.to(q.device)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim) + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
best_s, best_v, best_score, best_ty = 0, 0, 0, ""
@@ -531,7 +531,7 @@ def forward(
if self.is_search:
if os.path.exists(self.config_path):
config_list = json.load(open(self.config_path))
- if self.layer_idx < len(config_list):
+ if self.config.num_hidden_layers == len(config_list):
assert False, f"Search completed. The config is located in {self.config_path}."
else:
config_list = []
@@ -543,7 +543,7 @@
q = query_states[:, head, :, :].unsqueeze(1)
k = key_states[:, head, :, :].unsqueeze(1)
v = value_states[:, head, :, :].unsqueeze(1)
- if self.is_search:
+ if self.is_search and self.layer_idx >= len(config_list):
config[head] = search_pattern(q, k, head)
if self.layer_idx >= self.starting_layer and not self.is_search:
attn_output = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)
@@ -553,13 +553,14 @@
attn_output = gather_qkv(q, k, v, attention_mask)
output[:, head:head + 1] = attn_output
if self.is_search:
- config_list.append(config)
+ if len(config):
+ config_list.append(config)
with open(self.config_path, 'w') as json_file:
json.dump(config_list, json_file)
else:
output = flash_attn_func(query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1,2), 0.0, softmax_scale=None, causal=q_len != 1).view(bsz, query_states.size(1), q_len, self.head_dim)
attn_output = output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value
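The three search-related changes above make the offline head-pattern search resumable: previously saved per-layer configs are reloaded, layers that already have an entry are skipped, and the run only aborts once every hidden layer is covered. A hedged sketch of that bookkeeping; the helper name is illustrative.

import json
import os

def load_or_resume_search(config_path: str, num_hidden_layers: int) -> list:
    config_list = json.load(open(config_path)) if os.path.exists(config_path) else []
    if len(config_list) == num_hidden_layers:
        raise RuntimeError(f"Search completed. The config is located in {config_path}.")
    # layers with index < len(config_list) already have a pattern and are skipped
    return config_list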
@@ -741,7 +742,7 @@ def forward(
output[:, head:head + 1] = self.gather_last_q_vertical_slash_topk_v4(q, k, v, head)

attn_output = output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

@@ -757,7 +758,7 @@
attn_output = gather_qkv(q, k, v, attention_mask)
output[:, head:head + 1] = attn_output
attn_output = output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value
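The SEARCH_MASK changes above follow a build-once, move-per-call pattern: the additive causal mask is constructed a single time, cached at module level, and only relocated to the device of the current query tensor, which can differ per layer once the model is sharded across GPUs. Likewise, the reshape uses num_heads * head_dim, which always matches the attention output's actual width even if it differs from the configured hidden_size. A minimal sketch of the mask-caching pattern, with illustrative names:

import torch

_SEARCH_MASK = None

def get_causal_mask(q: torch.Tensor) -> torch.Tensor:
    global _SEARCH_MASK
    q_len = q.size(-2)
    if _SEARCH_MASK is None or _SEARCH_MASK.size(-1) < q_len:
        mask = torch.full((q_len, q_len), torch.finfo(q.dtype).min)
        mask = torch.triu(mask, diagonal=1)      # additive -inf strictly above the diagonal
        _SEARCH_MASK = mask[None, None, :, :]
    # move only the slice we need onto the query's device
    return _SEARCH_MASK[..., :q_len, :q_len].to(q.device)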
6 changes: 6 additions & 0 deletions minference/patch.py
@@ -75,6 +75,9 @@ def apply_rotary_pos_emb(self, x, length, right, cos, sin):
cos = cos[:, :, right - length : right, :]
sin = sin[:, :, right - length : right, :]

+ if cos.device != x.device:
+ cos, sin = cos.to(x.device), sin.to(x.device)

return ((x.float() * cos) + (self.rotate_half(x).float() * sin)).to(dtype)

def _update_cos_sin_tables(self, x, seq_dim):
@@ -144,6 +147,9 @@ def apply_rotary_pos_emb_one_angle(self, x: torch.Tensor, index):
cos = cos[:, :, index - 1 : index, :]
sin = sin[:, :, index - 1 : index, :]

+ if cos.device != x.device:
+ cos, sin = cos.to(x.device), sin.to(x.device)

return ((x.float() * cos) + (self.rotate_half(x).float() * sin)).to(dtype)

def forward(
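Both guards added above address the same issue: the cached rotary cos/sin tables live on the GPU where they were first computed, while under pipeline parallelism the activations of later layers live elsewhere. A standalone sketch of the device-safe rotary application; function names are illustrative.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    dtype = x.dtype
    # cos/sin may have been cached on a different GPU than x
    if cos.device != x.device:
        cos, sin = cos.to(x.device), sin.to(x.device)
    return ((x.float() * cos) + (rotate_half(x).float() * sin)).to(dtype)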
