feat: RL training support for VERL #1196
Open: shihaobai wants to merge 64 commits into main from rl_verl_rebase
+5,513 −297
Commits (64)
b636310  add /flush_cache (#1108)  [shihaobai]
60c379e  Aborted reqs (#1113)  [shihaobai]
4095831  flush cache mulit node (#1116)  [shihaobai]
ca9325f  [bugfix]: flush cache in single node (#1118)  [shihaobai]
9948925  add pause and continue (#1120)  [shihaobai]
4b32287  add launch_server and StartArgs (#1119)  [sufubao]
27abcf5  Update weight (#1127)  [kingder]
c210c82  release and resume (#1122)  [shihaobai]
094df8c  use portpicker (#1142)  [sufubao]
560be02  Rl weight (#1143)  [shihaobai]
3d225d7  add_cli  [sufubao]
499074a  add 30b moe configs  [shihaobai]
f737585  update requirement  [shihaobai]
8a67a47  add-neo-chat
fdc1369  add-neo-chat
e8e7416  add-neo-chat
ba44983  add-neo-chat
4d41a33  add-neo-chat
0e8845c  fix-neo-chat
b48cd49  fix-neo-chat-position-ids-h
7a904f3  add-neo-chat-dense
4b757dd  add-neo-chat-dense
e208733  support verl.
245357c  improve0108
6503ac8  add min/max pixels sampling parameters
07df460  fix fused_moe not installed use pip.
a6f00fb  add visual nccl port alloc  [shihaobai]
9360197  fix0115
920a741  fix0115
3aa5e18  fp8 online quant for moe  [shihaobai]
7cb890b  hotfix for fa3 of llama  [shihaobai]
c242a75  fp8w8a8 triton config  [shihaobai]
a0195aa  fp16 config  [shihaobai]
7f0c437  release ipc tensor early.
5738d9e  bugfix: fix flattened_bucket update weights  [yqyao]
e11bf58  bugfix: fix update_weights from tensor  [yqyao]
f767609  merge main  [shihaobai]
ce76f8a  fix start  [shihaobai]
45259ec  add-merge-kv-mode
da3b53d  add-neo-chat0129
1e066d0  Merge branch 'add-neo-chat-rebase' into rl_verl
043e898  moe fused weight  [shihaobai]
52085a4  Merge branch 'rl_verl_rebase' of https://github.com/ModelTC/lightllm …  [shihaobai]
80cfcc4  fix neo  [shihaobai]
6bbdb4f  fix launch  [shihaobai]
e436ba5  fix launch  [shihaobai]
aef65bc  fix tp slice for merged moe weight  [shihaobai]
bc87692  fix fusemoe weight  [shihaobai]
cf5bcbf  fa3 for neo  [shihaobai]
a23288b  fix dead visual process  [shihaobai]
f558540  auto visual dp  [shihaobai]
12c6c6b  fix format  [shihaobai]
fd91cad  fix decode scale
2681263  add new mode support text_ids+image_ids
fd17aa0  add new mode support text_ids+image_ids
e516bd9  add cuda empty cache  [shihaobai]
81a0c12  add invalid token ids to sampling_param for rl training  [shihaobai]
14132d5  add unitest for apply_invalid_tokens  [shihaobai]
ed41960  add gc collect  [shihaobai]
706ae2e  logit_bias  [shihaobai]
f432f5a  logit_bias  [shihaobai]
92bf83a  Merge branch 'main' into rl_verl_rebase  [shihaobai]
8f8ed44  merge main  [shihaobai]
cac2edf  neo moe inferece speedup  [shihaobai]
```diff
@@ -6,7 +6,7 @@
 import json
 import torch
 import torch.nn.functional as F
-from typing import final, List
+from typing import final, List, Optional
 from tqdm import tqdm

 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
@@ -32,6 +32,10 @@
 from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel
 from lightllm.common.triton_utils.autotuner import Autotuner
 from lightllm.utils.infer_utils import post_empty_cache
+from lightllm.utils.torch_memory_saver_utils import (
+    TorchMemorySaverWrapper,
+    MemoryTag,
+)
 from .attention import get_prefill_att_backend_class, get_decode_att_backend_class
 from .attention import BaseAttBackend
@@ -90,6 +94,7 @@ def __init__(self, kvargs):
         self.tp_world_size_ = get_dp_world_size()
         self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

+        self.torch_memory_saver = TorchMemorySaverWrapper(self.args.enable_torch_memory_saver)
         self.is_mtp_mode = self.args.mtp_mode in [
             "vanilla_with_att",
             "eagle_with_att",
@@ -103,15 +108,17 @@ def __init__(self, kvargs):
         self._verify_params()
         self._init_quant()

-        self._init_weights()
-        self._init_mem_manager()
-        self._init_kv_move_buffer()
+        with self.torch_memory_saver.region(tag=MemoryTag.WEIGHT, enable_cpu_backup=self.args.enable_weight_cpu_backup):
+            self._init_weights()
+        with self.torch_memory_saver.region(tag=MemoryTag.KV_CACHE):
+            self._init_mem_manager()
+            self._init_kv_move_buffer()
+            self._init_req_manager()
         self._check_mem_size()
-        self._init_req_manager()
         self._init_infer_layer()
         self._init_some_value()
         self._init_custom()
-        self._load_hf_weights()
+        self.load_weights(self.weight_dict)
         # the wait must come before CUDA graph init, to avoid erroneous capture
         self._wait_other_modules_ready()
```
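For orientation, here is a minimal sketch of what a tag-scoped wrapper like the TorchMemorySaverWrapper imported above could look like. It is illustrative only: the PR's real implementation lives in lightllm/utils/torch_memory_saver_utils.py, and the exact calls into the torch_memory_saver package (region, pause, resume, enable_cpu_backup) are assumptions about its public API.

```python
from contextlib import nullcontext
from enum import Enum


class MemoryTag(str, Enum):
    # The three tags used in this diff.
    WEIGHT = "weight"
    KV_CACHE = "kv_cache"
    GRAPH = "graph"


class TorchMemorySaverWrapper:
    """No-op passthrough when disabled; otherwise delegates to the
    torch_memory_saver package (delegation details are assumed)."""

    def __init__(self, enable: bool):
        self.enable = enable
        self._saver = None
        if enable:
            # Imported lazily so the dependency is only needed when the
            # feature is switched on.
            from torch_memory_saver import torch_memory_saver
            self._saver = torch_memory_saver

    def region(self, tag: MemoryTag, enable_cpu_backup: bool = False):
        # Allocations made inside this context are tracked under `tag`,
        # so they can later be paused (freed) and resumed as a group.
        if not self.enable:
            return nullcontext()
        return self._saver.region(tag=tag.value, enable_cpu_backup=enable_cpu_backup)

    def pause(self, tag: MemoryTag):
        if self.enable:
            self._saver.pause(tag.value)

    def resume(self, tag: MemoryTag):
        if self.enable:
            self._saver.resume(tag.value)
```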
```diff
@@ -176,17 +183,15 @@ def _init_weights(self, start_layer_index=0):
         ]
         return

-    def _load_hf_weights(self):
+    def load_weights(self, weight_dict: dict):
+        assert weight_dict is None or isinstance(weight_dict, dict), "weight_dict must be a dict or None"
         load_hf_weights(
             self.data_type,
-            weight_dir=self.weight_dir_,
+            self.weight_dir_,
             pre_post_layer=self.pre_post_weight,
             transformer_layer_list=self.trans_layers_weight,
-            weight_dict=self.weight_dict,
+            weight_dict=weight_dict,
         )
         self.pre_post_weight.verify_load()
         [weight.verify_load() for weight in self.trans_layers_weight]
         return

     def _init_mem_manager(self):
         assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
```
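The rename from _load_hf_weights to load_weights turns disk loading into one case of a general entry point that can also consume an in-memory state dict, which is what an RL trainer needs for weight sync. A hedged usage sketch follows; the model variable, parameter name, and shape are made up, and the None fallback is an inference from load_hf_weights taking weight_dict=None.

```python
import torch

# `model` stands for an already-initialized model instance exposing the
# new API; the key and shape below are illustrative only.
weight_dict = {
    "model.layers.0.self_attn.q_proj.weight": torch.zeros(4096, 4096, dtype=torch.bfloat16),
}

model.load_weights(weight_dict)  # in-place update from trainer-supplied tensors
model.load_weights(None)         # presumably falls back to reading weight_dir_ from disk
```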
```diff
@@ -884,6 +889,7 @@ def _check_max_len_infer(self):
             )
             logger.error(exception_str)
             raise Exception(exception_str)
+        torch.cuda.empty_cache()
         return

     def autotune_layers(self):
@@ -1012,6 +1018,9 @@ def _init_padded_req(self):
         del b_seq_len
         del b_ready_cache_len
         del model_output
+        del b_mtp_index
+        del b_prefill_start_loc
+        del b_q_seq_len
         torch.cuda.empty_cache()
         return

@@ -1032,3 +1041,71 @@ def _gen_special_model_input(self, token_num: int):
         special_model_input["mtp_draft_input_hiddens"] = None

         return special_model_input
+
+    def release_memory_occupation(self, tags: Optional[List[MemoryTag]]):
+        if tags is None:
+            self.release_all()
+            return
+        if MemoryTag.WEIGHT in tags:
+            self.release_weight()
+        if MemoryTag.KV_CACHE in tags:
+            self.release_kv_cache()
+        if MemoryTag.GRAPH in tags:
+            self.release_graph()
+        return
+
+    def resume_memory_occupation(self, tags: Optional[List[MemoryTag]]):
+        if tags is None:
+            self.resume_all()
+            return
+        if MemoryTag.WEIGHT in tags:
+            self.resume_weight()
+        if MemoryTag.KV_CACHE in tags:
+            self.resume_kv_cache()
+        if MemoryTag.GRAPH in tags:
+            self.resume_graph()
+        return
+
+    def release_weight(self):
+        self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def release_kv_cache(self):
+        self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def release_graph(self):
+        self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def release_all(self):
+        self.torch_memory_saver.pause(tag=MemoryTag.WEIGHT)
+        self.torch_memory_saver.pause(tag=MemoryTag.KV_CACHE)
+        self.torch_memory_saver.pause(tag=MemoryTag.GRAPH)
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def resume_weight(self):
+        torch.cuda.empty_cache()
+        gc.collect()
+        self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
+
+    def resume_kv_cache(self):
+        torch.cuda.empty_cache()
+        gc.collect()
+        self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
+
+    def resume_graph(self):
+        torch.cuda.empty_cache()
+        gc.collect()
+        self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
+
+    def resume_all(self):
+        torch.cuda.empty_cache()
+        gc.collect()
+        self.torch_memory_saver.resume(tag=MemoryTag.WEIGHT)
+        self.torch_memory_saver.resume(tag=MemoryTag.KV_CACHE)
+        self.torch_memory_saver.resume(tag=MemoryTag.GRAPH)
```
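Taken together, release_memory_occupation, resume_memory_occupation, and load_weights give a VERL-style trainer the hooks it needs to time-share the GPU with the rollout engine. A sketch of that cycle under assumed names (only the three method calls come from this PR; the function and its callback are illustrative):

```python
def rl_iteration(engine_model, train_one_step):
    # 1. Hand the GPU to the trainer: tags=None releases weights, KV cache,
    #    and CUDA graph memory in one call.
    engine_model.release_memory_occupation(tags=None)

    # 2. Run the policy update (e.g., a PPO/GRPO step inside VERL) and get
    #    the updated parameters back as a name -> tensor dict.
    new_weight_dict = train_one_step()  # hypothetical trainer callback

    # 3. Re-materialize engine memory, then push the new policy weights.
    engine_model.resume_memory_occupation(tags=None)
    engine_model.load_weights(new_weight_dict)
```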
Contributor comment on lines 1106 to 1111:

The release_all method can be made more concise and maintainable by iterating over a list of memory tags. This avoids code repetition and makes it easier to add or remove tags in the future.
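A minimal sketch of the refactor the reviewer describes; the _ALL_TAGS name is introduced here for illustration and is not part of the PR:

```python
import gc

import torch

# Single source of truth for every managed tag; extending MemoryTag only
# requires updating this tuple (or iterating MemoryTag directly).
_ALL_TAGS = (MemoryTag.WEIGHT, MemoryTag.KV_CACHE, MemoryTag.GRAPH)


def release_all(self):
    for tag in _ALL_TAGS:
        self.torch_memory_saver.pause(tag=tag)
    torch.cuda.empty_cache()
    gc.collect()


def resume_all(self):
    torch.cuda.empty_cache()
    gc.collect()
    for tag in _ALL_TAGS:
        self.torch_memory_saver.resume(tag=tag)
```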
(The remaining changed files, including one new empty file, are not rendered in this view.)