96 changes: 84 additions & 12 deletions lightllm/common/basemodel/basemodel.py
@@ -6,7 +6,7 @@
import json
import torch
import torch.nn.functional as F
from typing import final
from typing import final, List, Optional
from tqdm import tqdm

from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
@@ -30,6 +30,10 @@
from lightllm.utils.envs_utils import set_model_init_status
from lightllm.common.triton_utils.autotuner import Autotuner
from lightllm.utils.infer_utils import post_empty_cache
from lightllm.utils.torch_memory_saver_utils import (
TorchMemorySaverWrapper,
MemoryTag,
)

logger = init_logger(__name__)

@@ -88,6 +92,7 @@ def __init__(self, kvargs):
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]
self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)

self._init_datatype()
self._init_config()
@@ -97,20 +102,29 @@ def __init__(self, kvargs):

# More contiguous GPU memory allocation gives better performance
if self.max_total_token_num is None:
self._init_weights()
self._init_mem_manager()
with self.torch_memory_saver.region(
tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup
):
self._init_weights()
with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE):
self._init_mem_manager()
else:
self._init_mem_manager()
self._init_weights()
with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE):
self._init_mem_manager()
with self.torch_memory_saver.region(
tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup
):
self._init_weights()
Review comment (Contributor, severity: medium) on lines 99 to +117:

The logic for wrapping _init_weights() and _init_mem_manager() with torch_memory_saver.region is duplicated in the if and else blocks. This can be refactored to reduce code duplication and improve maintainability by defining the context managers once before the conditional block.

Suggested change (replacing the duplicated if/else block above):
        weight_region = self.torch_memory_saver.region(
            tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup
        )
        kv_cache_region = self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE)
        if self.max_total_token_num is None:
            with weight_region:
                self._init_weights()
            with kv_cache_region:
                self._init_mem_manager()
        else:
            with kv_cache_region:
                self._init_mem_manager()
            with weight_region:
                self._init_weights()


self._init_kv_move_buffer()
self._check_mem_size()
self._init_req_manager()
with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE):
self._init_req_manager()
self._init_infer_layer()
self._init_some_value()
self._init_custom()
self._init_inferstate_cls()
self._autotune_warmup()
# self._autotune_warmup()
self._init_padded_req()
# the wait must happen before cuda graph init to avoid being captured by mistake
self._wait_other_modules_ready()
@@ -179,11 +193,13 @@ def _init_weights(self):
return

def load_weights(self, weight_dict: dict):
load_hf_weights(self.data_type,
self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
weight_dict=weight_dict)
load_hf_weights(
self.data_type,
self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
weight_dict=weight_dict,
)

def _init_mem_manager(self):
assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
@@ -766,6 +782,7 @@ def _check_max_len_infer(self):
)
logger.error(exception_str)
raise Exception(exception_str)
torch.cuda.empty_cache()
return

def autotune_layers(self):
@@ -896,6 +913,9 @@ def _init_padded_req(self):
del b_seq_len
del b_ready_cache_len
del model_output
del b_mtp_index
del b_prefill_start_loc
del b_q_seq_len
torch.cuda.empty_cache()
return

@@ -911,3 +931,55 @@ def _gen_special_model_input(self, token_num: int):
special_model_input["deepseekv3_mtp_draft_input_hiddens"] = None

return special_model_input

def release_memory_occupation(self, tags: Optional[List[MemoryTag]]):
if tags is None:
self.release_all()
return
if MemoryTag.WEIGHT in tags:
self.release_weight()
if MemoryTag.KV_CACHE in tags:
self.release_kv_cache()
if MemoryTag.GRAPH in tags:
self.release_graph()
return

def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]):
if tags is None:
self.resume_all()
return
if MemoryTag.WEIGHT in tags:
self.resume_weight()
if MemoryTag.KV_CACHE in tags:
self.resume_kv_cache()
if MemoryTag.GRAPH in tags:
self.resume_graph()
return
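
The release/resume pairs above are meant to be driven from an RL training loop (per the --enable_torch_memory_saver help text added in api_cli.py). A minimal usage sketch, assuming a model object exposing these methods and hypothetical run_rollout / run_optimizer_step callables that are not part of this diff:

from lightllm.utils.torch_memory_saver_utils import MemoryTag

def rl_step(model, run_rollout, run_optimizer_step):
    # Generation phase: weights, KV cache and captured CUDA graphs are resident on the GPU.
    run_rollout(model)
    # Hand the KV cache and graph memory back to the trainer for the update pass;
    # weights could be released as well when --enable_weight_cpu_backup is set.
    model.release_memory_occupation([MemoryTag.KV_CACHE, MemoryTag.GRAPH])
    run_optimizer_step()
    # Restore inference memory before the next rollout.
    model.resume_memory_occupation([MemoryTag.KV_CACHE, MemoryTag.GRAPH])
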
Review comment (Contributor, severity: medium) on lines +935 to +957:

The methods release_memory_occupation and resume_memory_occupation use a series of if statements. This can be refactored into a more data-driven approach using a dictionary mapping memory tags to functions. This improves scalability and maintainability.

    def release_memory_occupation(self, tags: Optional[List[MemoryTag]]):
        if tags is None:
            self.release_all()
            return

        actions = {
            MemoryTag.WEIGHT: self.release_weight,
            MemoryTag.KV_CACHE: self.release_kv_cache,
            MemoryTag.GRAPH: self.release_graph,
        }
        for tag in tags:
            if tag in actions:
                actions[tag]()
        return

    def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]):
        if tags is None:
            self.resume_all()
            return

        actions = {
            MemoryTag.WEIGHT: self.resume_weight,
            MemoryTag.KV_CACHE: self.resume_kv_cache,
            MemoryTag.GRAPH: self.resume_graph,
        }
        for tag in tags:
            if tag in actions:
                actions[tag]()
        return


def release_weight(self):
self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)

def release_kv_cache(self):
self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)

def release_graph(self):
self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)

def release_all(self):
self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
Review comment (Contributor, severity: medium) on lines +968 to +971:

The release_all method contains repetitive calls to self.torch_memory_saver.pause. This can be simplified by iterating over a list of memory tags.

Suggested change (replacing the repeated pause calls above):
    def release_all(self):
        for tag in [MemoryTag.WEIGHT, MemoryTag.KV_CACHE, MemoryTag.GRAPH]:
            self.torch_memory_saver.pause(tag=tag)


def resume_weight(self):
self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)

def resume_kv_cache(self):
self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)

def resume_graph(self):
self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)

def resume_all(self):
self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
Review comment (Contributor, severity: medium) on lines +982 to +985:

The resume_all method contains repetitive calls to self.torch_memory_saver.resume. This can be simplified by iterating over a list of memory tags.

Suggested change (replacing the repeated resume calls above):
    def resume_all(self):
        for tag in [MemoryTag.WEIGHT, MemoryTag.KV_CACHE, MemoryTag.GRAPH]:
            self.torch_memory_saver.resume(tag=tag)
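
The lightllm/utils/torch_memory_saver_utils.py module imported above is not shown in this PR. A minimal sketch of what the wrapper could look like, assuming it degrades to no-ops when --enable_torch_memory_saver is off and delegates to the third-party torch_memory_saver package when it is on; the upstream calls on the enabled path are assumptions rather than verified API:

import contextlib
from enum import Enum

import torch


class MemoryTag(str, Enum):
    WEIGHT = "weight"
    KV_CACHE = "kv_cache"
    GRAPH = "graph"


class TorchMemorySaverWrapper:
    """No-op wrapper unless the memory saver is enabled at server start."""

    def __init__(self, enable: bool):
        self.enable = enable
        if enable:
            # Assumption: the upstream package exposes a process-wide saver
            # with region()/pause()/resume() keyed by a string tag.
            from torch_memory_saver import torch_memory_saver

            self._saver = torch_memory_saver

    def region(self, tag: MemoryTag, enable_cpu_backup: bool = False):
        if not self.enable:
            return contextlib.nullcontext()
        return self._saver.region(tag=tag.value, enable_cpu_backup=enable_cpu_backup)

    def cuda_graph(self, graph_obj, pool=None):
        # With the saver disabled this is just the stock capture context.
        if not self.enable:
            return torch.cuda.graph(graph_obj, pool=pool)
        # Assumption: the enabled path tags capture-time allocations as GRAPH.
        return self._saver.cuda_graph(graph_obj, pool=pool)

    def pause(self, tag: MemoryTag):
        if self.enable:
            self._saver.pause(tag.value)

    def resume(self, tag: MemoryTag):
        if self.enable:
            self._saver.resume(tag.value)
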

9 changes: 7 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -7,6 +7,10 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
from lightllm.utils.torch_memory_saver_utils import (
TorchMemorySaverWrapper,
MemoryTag,
)
from .infer_struct import InferStateInfo


@@ -24,6 +28,7 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
self.max_batch_size = max_batch_size
self.graph_max_len_in_batch = max_len_in_batch
self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)

# gen cuda graph batch_sizes
# cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
@@ -82,7 +87,7 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf
torch.cuda.synchronize()

with lightllm_capture_graph(dist_group):
with torch.cuda.graph(graph_obj, pool=self.mempool):
with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool):
model_output = decode_func(input_ids, infer_state)
self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output)
graph_obj.replay()
@@ -111,7 +116,7 @@ def _capture_decode_overlap(
torch.cuda.synchronize()
with lightllm_capture_graph(dist_group1):
with lightllm_capture_graph(dist_group):
with torch.cuda.graph(graph_obj, pool=self.mempool):
with self.torch_memory_saver.cuda_graph(graph_obj, pool=self.mempool):
model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1)
self.graph[batch_size] = (
graph_obj,
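
Routing capture through self.torch_memory_saver.cuda_graph presumably lets capture-time allocations be tagged as MemoryTag.GRAPH so that release_graph()/resume_graph() in basemodel.py can pause and restore them; when the saver is disabled the wrapper would be expected to fall back to the plain torch.cuda.graph context.
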
70 changes: 69 additions & 1 deletion lightllm/common/basemodel/layer_weights/hf_load_utils.py
@@ -5,6 +5,8 @@
from tqdm import tqdm
import lightllm.utils.petrel_helper as utils
from lightllm.utils.dist_utils import get_current_device_id
from queue import Queue
from threading import Thread


def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None):
@@ -28,7 +30,7 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay
gc.collect()


def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None):
def load_hf_weights_old(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None):
if isinstance(data_type, str):
data_type = torch.float16 if data_type == "fp16" else torch.float32
if pre_post_layer is not None:
@@ -70,3 +72,69 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye
pass

return


def _read_file(file_, use_safetensors, weight_dir):
if use_safetensors:
weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
weights = {k: weights.get_tensor(k) for k in weights.keys()}
else:
weights = utils.PetrelHelper.load(os.path.join(weight_dir, file_), map_location="cpu")

return weights


def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None):
if isinstance(data_type, str):
data_type = torch.float16 if data_type == "fp16" else torch.float32
if pre_post_layer is not None:
assert pre_post_layer.data_type_ == data_type, "type is not right"
if transformer_layer_list is not None:
assert transformer_layer_list[0].data_type_ == data_type, "type is not right"
if weight_dict:
if pre_post_layer is not None:
pre_post_layer.load_hf_weights(weight_dict)
if transformer_layer_list is not None:
for layer in transformer_layer_list:
layer.load_hf_weights(weight_dict)
del weight_dict
return
use_safetensors = True
files = utils.PetrelHelper.list(weight_dir, extension="all")
candidate_files = list(filter(lambda x: x.endswith(".safetensors"), files))
if len(candidate_files) == 0:
use_safetensors = False
candidate_files = list(filter(lambda x: x.endswith(".bin"), files))
assert len(candidate_files) != 0, "can only support pytorch tensor and safetensors format for weights."

weight_queue = Queue(maxsize=5)  # bound host memory usage

def producer(chunk):
for file_ in chunk:
weights = _read_file(file_, use_safetensors, weight_dir)
weight_queue.put(weights)

LOADWORKER = int(os.environ.get("LOADWORKER", 1))

num_producers = min(LOADWORKER, len(candidate_files))  # number of producer threads
chunk_size = (len(candidate_files) + num_producers - 1) // num_producers
file_chunks = [candidate_files[i : i + chunk_size] for i in range(0, len(candidate_files), chunk_size)]

producer_threads = []
for i, chunk in enumerate(file_chunks):
thread = Thread(target=producer, args=(chunk,), name=f"Producer-{i}")
thread.start()
producer_threads.append(thread)

for _ in tqdm(range(len(candidate_files)), desc="Loading weights"):
weights = weight_queue.get()
if pre_post_layer is not None:
pre_post_layer.load_hf_weights(weights)
if transformer_layer_list is not None:
for layer in transformer_layer_list:
layer.load_hf_weights(weights)
del weights
gc.collect()

for thread in producer_threads:
thread.join()
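
Note that the rewritten loader reads checkpoint shards in background producer threads while the main thread consumes them: the producer count is capped by the LOADWORKER environment variable (default 1), and the bounded queue (maxsize=5) keeps at most five shards' worth of CPU tensors in flight at once.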
6 changes: 6 additions & 0 deletions lightllm/server/api_cli.py
@@ -537,4 +537,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used."""
)
parser.add_argument(
"--enable_torch_memory_saver",
action="store_true",
help="""enable torch memory saver, which is used for release_memory and resume_memory during RL training.""",
)
parser.add_argument("--enable_weight_cpu_backup", action="store_true", help="""enable weight cpu backup.""")
return parser
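
For example (launch command shown for illustration only; the exact entry point is not part of this diff), the new flags would be combined as: python -m lightllm.server.api_server --model_dir /path/to/model --enable_torch_memory_saver --enable_weight_cpu_backup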