Commits
64 commits
b636310
add /flush_cache (#1108)
shihaobai Nov 14, 2025
60c379e
Aborted reqs (#1113)
shihaobai Nov 18, 2025
4095831
flush cache multi node (#1116)
shihaobai Nov 19, 2025
ca9325f
[bugfix]: flush cache in single node (#1118)
shihaobai Nov 19, 2025
9948925
add pause and continue (#1120)
shihaobai Nov 19, 2025
4b32287
add launch_server and StartArgs (#1119)
sufubao Nov 21, 2025
27abcf5
Update weight (#1127)
kingder Dec 1, 2025
c210c82
release and resume (#1122)
shihaobai Dec 1, 2025
094df8c
use portpicker (#1142)
sufubao Dec 8, 2025
560be02
Rl weight (#1143)
shihaobai Dec 8, 2025
3d225d7
add_cli
sufubao Nov 25, 2025
499074a
add 30b moe configs
shihaobai Dec 8, 2025
f737585
update requirement
shihaobai Dec 9, 2025
8a67a47
add-neo-chat
Dec 26, 2025
fdc1369
add-neo-chat
Dec 30, 2025
e8e7416
add-neo-chat
Dec 31, 2025
ba44983
add-neo-chat
Dec 31, 2025
4d41a33
add-neo-chat
Dec 31, 2025
0e8845c
fix-neo-chat
Jan 1, 2026
b48cd49
fix-neo-chat-position-ids-h
Jan 5, 2026
7a904f3
add-neo-chat-dense
Jan 6, 2026
4b757dd
add-neo-chat-dense
Jan 6, 2026
e208733
support verl.
Jan 8, 2026
245357c
improve0108
Jan 8, 2026
6503ac8
add min/max pixels sampling parameters
Jan 8, 2026
07df460
fix fused_moe not installed use pip.
Jan 12, 2026
a6f00fb
add visual nccl port alloc
shihaobai Jan 15, 2026
9360197
fix0115
Jan 15, 2026
920a741
fix0115
Jan 15, 2026
3aa5e18
fp8 online quant for moe
shihaobai Jan 16, 2026
7cb890b
hotfix for fa3 of llama
shihaobai Jan 16, 2026
c242a75
fp8w8a8 triton config
shihaobai Jan 19, 2026
a0195aa
fp16 config
shihaobai Jan 19, 2026
7f0c437
release ipc tensor early.
Jan 21, 2026
5738d9e
bugfix: fix flattened_bucket update weights
yqyao Jan 21, 2026
e11bf58
bugfix: fix update_weights from tensor
yqyao Jan 22, 2026
f767609
merge main
shihaobai Jan 28, 2026
ce76f8a
fix start
shihaobai Jan 29, 2026
45259ec
add-merge-kv-mode
Jan 29, 2026
da3b53d
add-neo-chat0129
Jan 29, 2026
1e066d0
Merge branch 'add-neo-chat-rebase' into rl_verl
Jan 29, 2026
043e898
moe fused weight
shihaobai Jan 30, 2026
52085a4
Merge branch 'rl_verl_rebase' of https://github.com/ModelTC/lightllm …
shihaobai Jan 30, 2026
80cfcc4
fix neo
shihaobai Jan 30, 2026
6bbdb4f
fix launch
shihaobai Jan 30, 2026
e436ba5
fix launch
shihaobai Jan 30, 2026
aef65bc
fix tp slice for merged moe weight
shihaobai Jan 30, 2026
bc87692
fix fusemoe weight
shihaobai Jan 30, 2026
cf5bcbf
fa3 for neo
shihaobai Jan 30, 2026
a23288b
fix dead visual process
shihaobai Jan 30, 2026
f558540
auto visual dp
shihaobai Jan 30, 2026
12c6c6b
fix format
shihaobai Jan 30, 2026
fd91cad
fix decode scale
Feb 2, 2026
2681263
add new mode support text_ids+image_ids
Feb 2, 2026
fd17aa0
add new mode support text_ids+image_ids
Feb 2, 2026
e516bd9
add cuda empty cache
shihaobai Feb 2, 2026
81a0c12
add invalid token ids to sampling_param for rl training
shihaobai Feb 2, 2026
14132d5
add unitest for apply_invalid_tokens
shihaobai Feb 2, 2026
ed41960
add gc collect
shihaobai Feb 3, 2026
706ae2e
logit_bias
shihaobai Feb 3, 2026
f432f5a
logit_bias
shihaobai Feb 3, 2026
92bf83a
Merge branch 'main' into rl_verl_rebase
shihaobai Feb 3, 2026
8f8ed44
merge main
shihaobai Feb 4, 2026
cac2edf
neo moe inference speedup
shihaobai Feb 6, 2026
1 change: 1 addition & 0 deletions lightllm/common/basemodel/attention/base_att.py
@@ -65,6 +65,7 @@ class AttControl:
mla_prefill_dict: Dict = None
mla_decode: bool = False
mla_decode_dict: Dict = None
scale: float = None


@dataclass
7 changes: 5 additions & 2 deletions lightllm/common/basemodel/attention/fa3/fp.py
@@ -220,8 +220,11 @@ def _normal_decode_att(
sink_weight = None

k_descale, v_descale = None, None # disable quantization
Lq = q.shape[-1]
sm_scale = 1.0 / (Lq ** 0.5)
if att_control.scale is not None:
sm_scale = att_control.scale
else:
Lq = q.shape[-1]
sm_scale = 1.0 / (Lq ** 0.5)
o = flash_attn_with_kvcache(
q=q,
k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),
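For context on the two hunks above: `AttControl` gains an optional `scale` field, and the FA3 decode path now uses it as the softmax scale when it is set, otherwise falling back to the usual `1.0 / sqrt(head_dim)`. A minimal standalone sketch of that selection logic (simplified; only the names that appear in the diff are real, the helper itself is illustrative):

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttControl:
    # only the field relevant to softmax scaling is shown here
    scale: Optional[float] = None


def pick_sm_scale(att_control: AttControl, head_dim: int) -> float:
    # explicit override, e.g. for models trained with a non-standard attention scale
    if att_control.scale is not None:
        return att_control.scale
    # default scaled dot-product attention scale: 1 / sqrt(d_k)
    return 1.0 / math.sqrt(head_dim)


print(pick_sm_scale(AttControl(scale=0.0883), head_dim=128))  # 0.0883 (override wins)
print(pick_sm_scale(AttControl(), head_dim=128))              # ~0.0884 (default)
```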
101 changes: 89 additions & 12 deletions lightllm/common/basemodel/basemodel.py
@@ -6,7 +6,7 @@
import json
import torch
import torch.nn.functional as F
from typing import final, List
from typing import final, List, Optional
from tqdm import tqdm

from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
@@ -32,6 +32,10 @@
from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel
from lightllm.common.triton_utils.autotuner import Autotuner
from lightllm.utils.infer_utils import post_empty_cache
from lightllm.utils.torch_memory_saver_utils import (
TorchMemorySaverWrapper,
MemoryTag,
)
from .attention import get_prefill_att_backend_class, get_decode_att_backend_class
from .attention import BaseAttBackend

@@ -90,6 +94,7 @@ def __init__(self, kvargs):
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)
self.is_mtp_mode = self.args.mtp_mode in [
"vanilla_with_att",
"eagle_with_att",
@@ -103,15 +108,17 @@
self._verify_params()
self._init_quant()

self._init_weights()
self._init_mem_manager()
self._init_kv_move_buffer()
with self.torch_memory_saver.region(tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup):
self._init_weights()
with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE):
self._init_mem_manager()
self._init_kv_move_buffer()
self._init_req_manager()
self._check_mem_size()
self._init_req_manager()
self._init_infer_layer()
self._init_some_value()
self._init_custom()
self._load_hf_weights()
self.load_weights(self.weight_dict)
# the wait must happen before cudagraph init to avoid erroneous capture
self._wait_other_modules_ready()

Expand Down Expand Up @@ -176,17 +183,15 @@ def _init_weights(self, start_layer_index=0):
]
return

def _load_hf_weights(self):
def load_weights(self, weight_dict: dict):
assert weight_dict is None or isinstance(weight_dict, dict), "weight_dict must be a dict or None"
load_hf_weights(
self.data_type,
weight_dir=self.weight_dir_,
self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
weight_dict=self.weight_dict,
weight_dict=weight_dict,
)
self.pre_post_weight.verify_load()
[weight.verify_load() for weight in self.trans_layers_weight]
return

def _init_mem_manager(self):
assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
@@ -884,6 +889,7 @@ def _check_max_len_infer(self):
)
logger.error(exception_str)
raise Exception(exception_str)
torch.cuda.empty_cache()
return

def autotune_layers(self):
@@ -1012,6 +1018,9 @@ def _init_padded_req(self):
del b_seq_len
del b_ready_cache_len
del model_output
del b_mtp_index
del b_prefill_start_loc
del b_q_seq_len
torch.cuda.empty_cache()
return

@@ -1032,3 +1041,71 @@ def _gen_special_model_input(self, token_num: int):
special_model_input["mtp_draft_input_hiddens"] = None

return special_model_input

def release_memory_occupation(self, tags: Optional[List[MemoryTag]]):
if tags is None:
self.release_all()
return
if MemoryTag.WEIGHT in tags:
self.release_weight()
if MemoryTag.KV_CACHE in tags:
self.release_kv_cache()
if MemoryTag.GRAPH in tags:
self.release_graph()
return

def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]):
if tags is None:
self.resume_all()
return
if MemoryTag.WEIGHT in tags:
self.resume_weight()
if MemoryTag.KV_CACHE in tags:
self.resume_kv_cache()
if MemoryTag.GRAPH in tags:
self.resume_graph()
return

def release_weight(self):
self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
torch.cuda.empty_cache()
gc.collect()

def release_kv_cache(self):
self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
torch.cuda.empty_cache()
gc.collect()

def release_graph(self):
self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
torch.cuda.empty_cache()
gc.collect()

def release_all(self):
self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
Comment on lines +1084 to +1087
medium

The release_all method can be made more concise and maintainable by iterating over a list of memory tags. This avoids code repetition and makes it easier to add or remove tags in the future.

Suggested change
def release_all(self):
self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
def release_all(self):
for tag in [MemoryTag.WEIGHT, MemoryTag.KV_CACHE, MemoryTag.GRAPH]:
self.torch_memory_saver.pause(tag=tag)

torch.cuda.empty_cache()
gc.collect()

def resume_weight(self):
torch.cuda.empty_cache()
gc.collect()
self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)

def resume_kv_cache(self):
torch.cuda.empty_cache()
gc.collect()
self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)

def resume_graph(self):
torch.cuda.empty_cache()
gc.collect()
self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)

def resume_all(self):
torch.cuda.empty_cache()
gc.collect()
self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
Comment on lines 1106 to 1111
medium

The resume_all method can be made more concise and maintainable by iterating over a list of memory tags. This avoids code repetition and makes it easier to add or remove tags in the future.

Suggested change
def resume_all(self):
self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
def resume_all(self):
for tag in [MemoryTag.WEIGHT, MemoryTag.KV_CACHE, MemoryTag.GRAPH]:
self.torch_memory_saver.resume(tag=tag)

9 changes: 7 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -7,6 +7,10 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
from lightllm.utils.torch_memory_saver_utils import (
TorchMemorySaverWrapper,
MemoryTag,
)
from .infer_struct import InferStateInfo


@@ -24,6 +28,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
self.max_batch_size = max_batch_size
self.graph_max_len_in_batch = max_len_in_batch
self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)

# gen cuda graph batch_sizes
# cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
@@ -89,7 +94,7 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo):
delattr(infer_state, param_name)

with lightllm_capture_graph(dist_group):
with torch.cuda.graph(graph_obj, pool=self.mempool):
with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool):
model_output = decode_func(infer_state)
self.graph[batch_size] = (graph_obj, infer_state, model_output)
graph_obj.replay()
@@ -127,7 +132,7 @@ def _capture_decode_overlap(

with lightllm_capture_graph(dist_group1):
with lightllm_capture_graph(dist_group):
with torch.cuda.graph(graph_obj, pool=self.mempool):
with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool):
model_output, model_output1 = decode_func(infer_state, infer_state1)
self.graph[batch_size] = (
graph_obj,
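The cuda_graph.py change routes graph capture through `TorchMemorySaverWrapper.cuda_graph` instead of calling `torch.cuda.graph` directly; the wrapper lives in `lightllm.utils.torch_memory_saver_utils`, which is not part of this diff. A hypothetical sketch of such a wrapper, behaving exactly like `torch.cuda.graph` when the saver is disabled (the class below is illustrative, not the real implementation):

```python
import contextlib
import torch


class MemorySaverGraphWrapperSketch:
    """Illustrative stand-in: delegate to torch.cuda.graph, leaving a hook point
    where a real memory saver could tag the captured allocations."""

    def __init__(self, enable: bool):
        self.enable = enable

    @contextlib.contextmanager
    def cuda_graph(self, graph_obj: torch.cuda.CUDAGraph, pool=None):
        if self.enable:
            # the real wrapper presumably registers this capture with the memory
            # saver (e.g. under a GRAPH tag) so it can later be paused and resumed
            pass
        with torch.cuda.graph(graph_obj, pool=pool):
            yield
```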
3 changes: 2 additions & 1 deletion lightllm/common/basemodel/layer_weights/hf_load_utils.py
@@ -5,6 +5,8 @@
from tqdm import tqdm
import lightllm.utils.petrel_helper as utils
from lightllm.utils.dist_utils import get_current_device_id
from queue import Queue
from threading import Thread


def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None):
@@ -65,7 +67,6 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye
iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1)
desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers"
iterator = tqdm(iterator, total=len(candidate_files), desc=desc_str)

for _ in iterator:
pass

@@ -295,6 +295,7 @@ def _create_weight(self):
device_id=self.device_id_,
num_experts=self.local_n_routed_experts,
)
self.w1, self.w3 = w13_param_list
self.w1_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[0])
self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1])
self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2)
@@ -312,6 +313,8 @@ def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[st
for expert_idx, local_expert_idx in expert_idx_to_local_idx.items():
with self.lock:
self._load_expert(expert_idx, local_expert_idx, weights)
# for rl updated weight
self._load_merge_weight(weights)
self._load_expert_scale(
expert_idx,
local_expert_idx,
@@ -332,6 +335,7 @@ def _load_expert(
w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}"
w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}"
w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}"

row_slice_func = self.row_slicer._slice_weight
col_slice_func = self.col_slicer._slice_weight
if w1_weight in weights:
@@ -341,6 +345,19 @@
if w2_weight in weights:
self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx])

def _load_merge_weight(self, weights: Dict[str, torch.Tensor]):
w1_merge_weight = f"{self.weight_prefix}.{self.w1_weight_name}"
w2_merge_weight = f"{self.weight_prefix}.{self.w2_weight_name}"
w3_merge_weight = f"{self.weight_prefix}.{self.w3_weight_name}"
row_slice_func = self.row_slicer._slice_weight
col_slice_func = self.col_slicer._slice_weight
if w1_merge_weight in weights:
self.quant_method.load_weight(row_slice_func(weights[w1_merge_weight]), self.w1)
if w2_merge_weight in weights:
self.quant_method.load_weight(col_slice_func(weights[w2_merge_weight]), self.w2)
if w3_merge_weight in weights:
self.quant_method.load_weight(row_slice_func(weights[w3_merge_weight]), self.w3)

def _load_expert_scale(
self,
expert_idx: int,
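For orientation, the new `_load_merge_weight` accepts fused MoE tensors keyed without a per-expert index, which is how RL trainers can push all experts of a projection in one tensor; `_load_expert` keeps handling the usual per-expert keys. An illustration of the two layouts (the prefix and projection names below are placeholders, and the stacked shape is an assumption consistent with the slicer changes further down):

```python
import torch

# Hypothetical key names; real checkpoints use the module's weight_prefix and the
# configured w1/w2/w3 weight names (plus the quant method's weight suffix on the
# per-expert path).
prefix = "model.layers.0.mlp.experts"
n_experts, hidden, inter = 4, 8, 16

# Per-expert layout, consumed by _load_expert(): one tensor per expert index.
per_expert_weights = {
    f"{prefix}.{e}.gate_proj.weight": torch.randn(inter, hidden)
    for e in range(n_experts)
}

# Fused layout, consumed by _load_merge_weight(): no expert index in the key;
# all experts are assumed stacked along a leading dim, e.g. (E, out, in).
fused_weights = {
    f"{prefix}.gate_proj.weight": torch.randn(n_experts, inter, hidden),
}
```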
@@ -47,17 +47,17 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten


# By default, a weight's shape is out x in, which is currently the most common convention.
# So row-wise slicing is along dim=0, and col-wise slicing is along dim=1
# The convention here is that row-wise slicing is along the second-to-last dim (-2), and col-wise slicing is along the last dim (-1)
class RowSliceMixin(SliceMixinTpl):
def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1):
super().__init__(tp_rank, tp_world_size, repeat_times)

def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
assert (
weight.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0
), f"tp slice error {weight.shape[0] * self.repeat_times_} % {self.tp_world_size_}"
start, end = self._get_slice_start_end(weight.shape[0])
return weight[start:end, :]
weight.shape[-2] * self.repeat_times_ % self.tp_world_size_ == 0
), f"tp slice error {weight.shape[-2] * self.repeat_times_} % {self.tp_world_size_}"
start, end = self._get_slice_start_end(weight.shape[-2])
return weight[..., start:end, :]

def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
assert (
@@ -75,17 +75,17 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times:

def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
assert (
weight_scale.shape[0] % self.tp_world_size_ == 0
), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}"
start, end = self._get_slice_start_end(weight_scale.shape[0])
return weight_scale[start:end]
weight_scale.shape[-2] % self.tp_world_size_ == 0
), f"tp slice error {weight_scale.shape[-2]} % {self.tp_world_size_}"
start, end = self._get_slice_start_end(weight_scale.shape[-2])
return weight_scale[..., start:end, :]

def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
assert (
weight_zero_point.shape[0] % self.tp_world_size_ == 0
), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}"
start, end = self._get_slice_start_end(weight_zero_point.shape[0])
return weight_zero_point[start:end]
weight_zero_point.shape[-2] % self.tp_world_size_ == 0
), f"tp slice error {weight_zero_point.shape[-2]} % {self.tp_world_size_}"
start, end = self._get_slice_start_end(weight_zero_point.shape[-2])
return weight_zero_point[..., start:end, :]


class ColSliceMixin(SliceMixinTpl):
@@ -94,10 +94,10 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times:

def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
assert (
weight.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0
), f"tp slice error {weight.shape[1] * self.repeat_times_ } % {self.tp_world_size_}"
start, end = self._get_slice_start_end(weight.shape[1])
return weight[:, start:end]
weight.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0
), f"tp slice error {weight.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}"
start, end = self._get_slice_start_end(weight.shape[-1])
return weight[..., start:end]

def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
return bias / self.tp_world_size_ * self.repeat_times_
@@ -110,16 +110,16 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times:
def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
assert (
weight_scale.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0
), f"tp slice error {weight_scale.shape[1] * self.repeat_times_ } % {self.tp_world_size_}"
start, end = self._get_slice_start_end(weight_scale.shape[1])
return weight_scale[:, start:end]
), f"tp slice error {weight_scale.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}"
start, end = self._get_slice_start_end(weight_scale.shape[-1])
return weight_scale[..., start:end]

def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
assert (
weight_zero_point.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0
), f"tp slice error {weight_zero_point.shape[1] * self.repeat_times_ } % {self.tp_world_size_}"
start, end = self._get_slice_start_end(weight_zero_point.shape[1])
return weight_zero_point[:, start:end]
weight_zero_point.shape[-1] * self.repeat_times_ % self.tp_world_size_ == 0
), f"tp slice error {weight_zero_point.shape[-1] * self.repeat_times_ } % {self.tp_world_size_}"
start, end = self._get_slice_start_end(weight_zero_point.shape[-1])
return weight_zero_point[..., start:end]


# AWQ quantized weights are stored in in x out format, so a customized implementation is needed.
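The slicer changes above generalize row/col tensor-parallel slicing from fixed dims 0 and 1 to the last two dims, so the same mixins handle both plain `(out, in)` weights and stacked 3D tensors such as fused per-expert MoE weights. A small self-contained check of that behavior (shapes are illustrative only):

```python
import torch


def row_slice(w: torch.Tensor, tp_rank: int, tp_world_size: int) -> torch.Tensor:
    # slice along dim -2 (the "out" dim), keeping any leading dims intact
    size = w.shape[-2] // tp_world_size
    return w[..., tp_rank * size:(tp_rank + 1) * size, :]


def col_slice(w: torch.Tensor, tp_rank: int, tp_world_size: int) -> torch.Tensor:
    # slice along dim -1 (the "in" dim)
    size = w.shape[-1] // tp_world_size
    return w[..., tp_rank * size:(tp_rank + 1) * size]


w2d = torch.randn(32, 16)     # plain (out, in) weight
w3d = torch.randn(4, 32, 16)  # stacked (num_experts, out, in) weight
assert row_slice(w2d, 0, 2).shape == (16, 16)
assert row_slice(w3d, 0, 2).shape == (4, 16, 16)
assert col_slice(w3d, 1, 2).shape == (4, 32, 8)
```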