-
Notifications
You must be signed in to change notification settings - Fork 299
release and resume #1122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
release and resume #1122
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |||||||||||||||
| import json | ||||||||||||||||
| import torch | ||||||||||||||||
| import torch.nn.functional as F | ||||||||||||||||
| from typing import final | ||||||||||||||||
| from typing import final, List, Optional | ||||||||||||||||
| from tqdm import tqdm | ||||||||||||||||
|
|
||||||||||||||||
| from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights | ||||||||||||||||
|
|
@@ -30,6 +30,10 @@ | |||||||||||||||
| from lightllm.utils.envs_utils import set_model_init_status | ||||||||||||||||
| from lightllm.common.triton_utils.autotuner import Autotuner | ||||||||||||||||
| from lightllm.utils.infer_utils import post_empty_cache | ||||||||||||||||
| from lightllm.utils.torch_memory_saver_utils import ( | ||||||||||||||||
| TorchMemorySaverWrapper, | ||||||||||||||||
| MemoryTag, | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| logger = init_logger(__name__) | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -88,6 +92,7 @@ def __init__(self, kvargs): | |||||||||||||||
| self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode | ||||||||||||||||
|
|
||||||||||||||||
| self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"] | ||||||||||||||||
| self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver) | ||||||||||||||||
|
|
||||||||||||||||
| self._init_datatype() | ||||||||||||||||
| self._init_config() | ||||||||||||||||
|
|
@@ -97,20 +102,29 @@ def __init__(self, kvargs): | |||||||||||||||
|
|
||||||||||||||||
| # 更连续的显存分配可以有更好的性能 | ||||||||||||||||
| if self.max_total_token_num is None: | ||||||||||||||||
| self._init_weights() | ||||||||||||||||
| self._init_mem_manager() | ||||||||||||||||
| with self.torch_memory_saver.region( | ||||||||||||||||
| tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup | ||||||||||||||||
| ): | ||||||||||||||||
| self._init_weights() | ||||||||||||||||
| with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): | ||||||||||||||||
| self._init_mem_manager() | ||||||||||||||||
| else: | ||||||||||||||||
| self._init_mem_manager() | ||||||||||||||||
| self._init_weights() | ||||||||||||||||
| with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): | ||||||||||||||||
| self._init_mem_manager() | ||||||||||||||||
| with self.torch_memory_saver.region( | ||||||||||||||||
| tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup | ||||||||||||||||
| ): | ||||||||||||||||
| self._init_weights() | ||||||||||||||||
|
|
||||||||||||||||
| self._init_kv_move_buffer() | ||||||||||||||||
| self._check_mem_size() | ||||||||||||||||
| self._init_req_manager() | ||||||||||||||||
| with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE): | ||||||||||||||||
| self._init_req_manager() | ||||||||||||||||
| self._init_infer_layer() | ||||||||||||||||
| self._init_some_value() | ||||||||||||||||
| self._init_custom() | ||||||||||||||||
| self._init_inferstate_cls() | ||||||||||||||||
| self._autotune_warmup() | ||||||||||||||||
| # self._autotune_warmup() | ||||||||||||||||
| self._init_padded_req() | ||||||||||||||||
| # wait必须在init cudagraph 之前,避免错误捕获 | ||||||||||||||||
| self._wait_other_modules_ready() | ||||||||||||||||
|
|
@@ -179,11 +193,13 @@ def _init_weights(self): | |||||||||||||||
| return | ||||||||||||||||
|
|
||||||||||||||||
| def load_weights(self, weight_dict: dict): | ||||||||||||||||
| load_hf_weights(self.data_type, | ||||||||||||||||
| self.weight_dir_, | ||||||||||||||||
| pre_post_layer=self.pre_post_weight, | ||||||||||||||||
| transformer_layer_list=self.trans_layers_weight, | ||||||||||||||||
| weight_dict=weight_dict) | ||||||||||||||||
| load_hf_weights( | ||||||||||||||||
| self.data_type, | ||||||||||||||||
| self.weight_dir_, | ||||||||||||||||
| pre_post_layer=self.pre_post_weight, | ||||||||||||||||
| transformer_layer_list=self.trans_layers_weight, | ||||||||||||||||
| weight_dict=weight_dict, | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| def _init_mem_manager(self): | ||||||||||||||||
| assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 | ||||||||||||||||
|
|
@@ -766,6 +782,7 @@ def _check_max_len_infer(self): | |||||||||||||||
| ) | ||||||||||||||||
| logger.error(exception_str) | ||||||||||||||||
| raise Exception(exception_str) | ||||||||||||||||
| torch.cuda.empty_cache() | ||||||||||||||||
| return | ||||||||||||||||
|
|
||||||||||||||||
| def autotune_layers(self): | ||||||||||||||||
|
|
@@ -896,6 +913,9 @@ def _init_padded_req(self): | |||||||||||||||
| del b_seq_len | ||||||||||||||||
| del b_ready_cache_len | ||||||||||||||||
| del model_output | ||||||||||||||||
| del b_mtp_index | ||||||||||||||||
| del b_prefill_start_loc | ||||||||||||||||
| del b_q_seq_len | ||||||||||||||||
| torch.cuda.empty_cache() | ||||||||||||||||
| return | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -911,3 +931,55 @@ def _gen_special_model_input(self, token_num: int): | |||||||||||||||
| special_model_input["deepseekv3_mtp_draft_input_hiddens"] = None | ||||||||||||||||
|
|
||||||||||||||||
| return special_model_input | ||||||||||||||||
|
|
||||||||||||||||
| def release_memory_occupation(self, tags: Optional[List[MemoryTag]]): | ||||||||||||||||
| if tags is None: | ||||||||||||||||
| self.release_all() | ||||||||||||||||
| return | ||||||||||||||||
| if MemoryTag.WEIGHT in tags: | ||||||||||||||||
| self.release_weight() | ||||||||||||||||
| if MemoryTag.KV_CACHE in tags: | ||||||||||||||||
| self.release_kv_cache() | ||||||||||||||||
| if MemoryTag.GRAPH in tags: | ||||||||||||||||
| self.release_graph() | ||||||||||||||||
| return | ||||||||||||||||
|
|
||||||||||||||||
| def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]): | ||||||||||||||||
| if tags is None: | ||||||||||||||||
| self.resume_all() | ||||||||||||||||
| return | ||||||||||||||||
| if MemoryTag.WEIGHT in tags: | ||||||||||||||||
| self.resume_weight() | ||||||||||||||||
| if MemoryTag.KV_CACHE in tags: | ||||||||||||||||
| self.resume_kv_cache() | ||||||||||||||||
| if MemoryTag.GRAPH in tags: | ||||||||||||||||
| self.resume_graph() | ||||||||||||||||
| return | ||||||||||||||||
|
Comment on lines
+935
to
+957
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The methods def release_memory_occupation(self, tags: Optional[List[MemoryTag]]):
if tags is None:
self.release_all()
return
actions = {
MemoryTag.WEIGHT: self.release_weight,
MemoryTag.KV_CACHE: self.release_kv_cache,
MemoryTag.GRAPH: self.release_graph,
}
for tag in tags:
if tag in actions:
actions[tag]()
return
def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]):
if tags is None:
self.resume_all()
return
actions = {
MemoryTag.WEIGHT: self.resume_weight,
MemoryTag.KV_CACHE: self.resume_kv_cache,
MemoryTag.GRAPH: self.resume_graph,
}
for tag in tags:
if tag in actions:
actions[tag]()
return |
||||||||||||||||
|
|
||||||||||||||||
| def release_weight(self): | ||||||||||||||||
| self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) | ||||||||||||||||
|
|
||||||||||||||||
| def release_kv_cache(self): | ||||||||||||||||
| self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) | ||||||||||||||||
|
|
||||||||||||||||
| def release_graph(self): | ||||||||||||||||
| self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) | ||||||||||||||||
|
|
||||||||||||||||
| def release_all(self): | ||||||||||||||||
| self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT) | ||||||||||||||||
| self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE) | ||||||||||||||||
| self.torch_memory_saver.pause(tag=MemoryTag.GRAPH) | ||||||||||||||||
|
Comment on lines
+968
to
+971
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| def resume_weight(self): | ||||||||||||||||
| self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) | ||||||||||||||||
|
|
||||||||||||||||
| def resume_kv_cache(self): | ||||||||||||||||
| self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) | ||||||||||||||||
|
|
||||||||||||||||
| def resume_graph(self): | ||||||||||||||||
| self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) | ||||||||||||||||
|
|
||||||||||||||||
| def resume_all(self): | ||||||||||||||||
| self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT) | ||||||||||||||||
| self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE) | ||||||||||||||||
| self.torch_memory_saver.resume(tag=MemoryTag.GRAPH) | ||||||||||||||||
|
Comment on lines
+982
to
+985
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic for wrapping
_init_weights()and_init_mem_manager()withtorch_memory_saver.regionis duplicated in theifandelseblocks. This can be refactored to reduce code duplication and improve maintainability by defining the context managers once before the conditional block.