Skip to content

add check for rope and tuning qwen3 on H200 #880

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 4}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 4}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"1": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 16, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 1}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 1}, "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 16, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 16, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 16, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "8192": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 8, "num_stages": 4}}
44 changes: 20 additions & 24 deletions lightllm/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ class LlamaFlashInferStateExtraInfo:
def __init__(self, model):
tp_world_size = get_dp_world_size()
self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size
self.tp_kv_head_num = model.config["num_key_value_heads"] // tp_world_size
self.head_dim = model.config["hidden_size"] // model.config["num_attention_heads"]
self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1)
head_dim = model.config["hidden_size"] // model.config["num_attention_heads"]
self.head_dim = model.config.get("head_dim", head_dim)
self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(get_current_device_id())
self.max_seq_length = model.max_seq_length
self.kv_indices_buffer = [
Expand Down Expand Up @@ -104,33 +105,29 @@ def _init_custom(self):
"""
模型特殊的一些初始化
"""
if self.config.get("use_rope_yarn", False) or (
self.config.get("rope_scaling", None) is not None
and self.config.get("rope_scaling", {}).get("type", "base") == "yarn"
):
rope_scaling = self.config.get("rope_scaling", None)
if rope_scaling is None:
self._init_to_get_rotary()
return

if "rope_type" in rope_scaling:
scaling_type = rope_scaling["rope_type"]
elif "type" in rope_scaling:
scaling_type = rope_scaling["type"]
else:
raise ValueError(f"Unknown RoPE scaling format {rope_scaling}")
if scaling_type == "yarn":
self._init_to_get_yarn_rotary()
elif self.config.get("use_dynamic_ntk", False) or (
self.config.get("rope_scaling", None) is not None
and self.config.get("rope_scaling", {}).get("type", "base") == "dynamic"
):
elif scaling_type == "dynamic":
self._init_to_get_dynamic_ntk_rotary()
elif (
self.config.get("rope_scaling", None) is not None
and self.config.get("rope_scaling", {}).get("type", "base") == "su"
):
elif scaling_type == "su":
self._init_to_su_rotary()
elif (
self.config.get("rope_scaling", None) is not None
and self.config.get("rope_scaling", {}).get("rope_type", "base") == "llama3"
):
elif scaling_type == "llama3":
self._init_to_get_llama3_rotary()
elif (
self.config.get("rope_scaling", None) is not None
and self.config.get("rope_scaling", {}).get("type", "base") == "mrope"
):
elif scaling_type == "mrope":
self._init_to_get_mrope_rotary()
else:
self._init_to_get_rotary()
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
return

def _init_weights(self):
Expand Down Expand Up @@ -269,7 +266,6 @@ def _init_to_get_yarn_rotary(self):
pos_freqs = base ** (torch.arange(0, dim, 2).float().cuda() / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scale * pos_freqs)

low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
inv_freq_mask = (
1 - linear_ramp_mask(low, high, dim // 2).float().cuda()
Expand Down
2 changes: 1 addition & 1 deletion lightllm/server/api_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req

prompt = await build_prompt(request, tools)
sampling_params_dict = {
"do_sample": request.do_sample,
"do_sample": True,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"temperature": request.temperature,
Expand Down