
Conversation


@moehanabi moehanabi commented Jan 8, 2026

What does this PR do?

Support router replay with sglang

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

This can be used together with sgl-project/sglang#15751 if you want to set chunked_prefill_size = -1.

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds support for router replay with sglang, which is useful for analyzing and debugging MoE models. The changes involve enabling the feature in the sglang server configuration and handling the routed_experts data in the agent loop and sglang server.

My review has identified a critical bug that could lead to a runtime crash due to an unhandled type for routed_experts. I've also pointed out a high-severity issue related to a local import that could cause runtime failures if an incompatible version of sglang is used. I've provided suggestions to fix both issues to improve the robustness of this new feature.

Comment on lines 332 to 344
else:
    from sglang.srt.layers.moe.routed_experts_capturer import extract_routed_experts_from_meta_info

    routed_experts = extract_routed_experts_from_meta_info(output).reshape(
        -1, self.model_config.hf_config.num_hidden_layers, self.model_config.hf_config.num_experts_per_tok
    )

Severity: high

The local import of extract_routed_experts_from_meta_info inside the generate method can lead to a runtime ImportError if this function is not available in the installed sglang version. This can make debugging difficult as the dependency is not explicit at the module level. It's better to handle this dependency explicitly.

Consider moving the import to the top of the file and wrapping it in a try...except block. This makes the dependency clear and allows for a more graceful failure if the required sglang version is not installed.

You can add the following at the top of the file:

try:
    from sglang.srt.layers.moe.routed_experts_capturer import extract_routed_experts_from_meta_info
except ImportError:
    extract_routed_experts_from_meta_info = None

Then, you can modify the logic in generate to check if the import was successful:

Suggested change
Before:

else:
    from sglang.srt.layers.moe.routed_experts_capturer import extract_routed_experts_from_meta_info
    routed_experts = extract_routed_experts_from_meta_info(output).reshape(
        -1, self.model_config.hf_config.num_hidden_layers, self.model_config.hf_config.num_experts_per_tok
    )

After:

else:
    if extract_routed_experts_from_meta_info is None:
        raise ImportError(
            "`extract_routed_experts_from_meta_info` is not available. "
            "Please check your sglang installation or version."
        )
    routed_experts = extract_routed_experts_from_meta_info(output).reshape(
        -1, self.model_config.hf_config.num_hidden_layers, self.model_config.hf_config.num_experts_per_tok
    )

Copilot AI left a comment

Pull request overview

This PR adds support for router replay functionality with SGLang backend, enabling the capture and replay of MoE (Mixture of Experts) routing decisions during rollout generation.

Key changes:

  • Enable returning routed experts information from SGLang server when the feature is enabled
  • Extract and reshape routed experts data from generation output
  • Handle both numpy array and torch tensor types for routed experts in agent loop processing
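To illustrate the last item, here is a minimal sketch, not the verl implementation itself (the helper name is hypothetical), of normalizing routed experts data that may arrive as either a numpy array or a torch tensor:

# Minimal sketch, assuming routed experts arrive as either numpy arrays or torch
# tensors; `as_routed_experts_tensor` is a hypothetical helper, not verl code.
import numpy as np
import torch


def as_routed_experts_tensor(routed_experts):
    """Return routed experts as a torch tensor, accepting numpy arrays or tensors."""
    if isinstance(routed_experts, np.ndarray):
        return torch.from_numpy(routed_experts)
    if isinstance(routed_experts, torch.Tensor):
        return routed_experts
    raise TypeError(f"Unsupported routed_experts type: {type(routed_experts)}")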

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

Changed files:

  • verl/workers/rollout/sglang_rollout/async_sglang_server.py: Adds configuration to enable routed experts return in SGLang; extracts and reshapes routed experts data from generation output
  • verl/experimental/agent_loop/agent_loop.py: Adds type checking to handle routed experts as either a numpy array or a torch tensor
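As a quick illustration of the server-side reshape, here is a minimal sketch with made-up sizes (the real dimensions come from the model's HF config), viewing the flat captured buffer as (tokens, layers, experts per token):

import torch

# Hypothetical sizes for illustration only; in the PR the last two come from
# hf_config.num_hidden_layers and hf_config.num_experts_per_tok.
num_tokens = 3
num_hidden_layers = 4
num_experts_per_tok = 2

flat = torch.arange(num_tokens * num_hidden_layers * num_experts_per_tok)
routed_experts = flat.reshape(-1, num_hidden_layers, num_experts_per_tok)
print(routed_experts.shape)  # torch.Size([3, 4, 2])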


Comment on lines 333 to 336
from sglang.srt.layers.moe.routed_experts_capturer import extract_routed_experts_from_meta_info

routed_experts = extract_routed_experts_from_meta_info(output).reshape(
    -1, self.model_config.hf_config.num_hidden_layers, self.model_config.hf_config.num_experts_per_tok
Copilot AI Jan 8, 2026

The import statement is placed inside the conditional block, which is good for avoiding unnecessary imports when the feature is disabled. However, this import is from an external library (sglang) which may not have a stable API. Consider adding a try-except ImportError block around this import to provide a clearer error message if the sglang version doesn't support this function, since this is a relatively new feature.

Suggested change
Before:

from sglang.srt.layers.moe.routed_experts_capturer import extract_routed_experts_from_meta_info
routed_experts = extract_routed_experts_from_meta_info(output).reshape(
    -1, self.model_config.hf_config.num_hidden_layers, self.model_config.hf_config.num_experts_per_tok

After:

try:
    from sglang.srt.layers.moe.routed_experts_capturer import (
        extract_routed_experts_from_meta_info,
    )
except ImportError as e:
    raise ImportError(
        "Failed to import 'extract_routed_experts_from_meta_info' from "
        "'sglang.srt.layers.moe.routed_experts_capturer'. This feature "
        "requires a version of 'sglang' that supports routed expert "
        "capturing. Please upgrade 'sglang' or disable "
        "'enable_rollout_routing_replay' in the rollout configuration."
    ) from e
routed_experts = extract_routed_experts_from_meta_info(output).reshape(
    -1,
    self.model_config.hf_config.num_hidden_layers,
    self.model_config.hf_config.num_experts_per_tok,

Comment on lines 331 to 336
    routed_experts = output.get("meta_info", {}).get("routed_experts", None)
else:
    from sglang.srt.layers.moe.routed_experts_capturer import extract_routed_experts_from_meta_info

    routed_experts = extract_routed_experts_from_meta_info(output).reshape(
        -1, self.model_config.hf_config.num_hidden_layers, self.model_config.hf_config.num_experts_per_tok
Copilot AI Jan 8, 2026

When skip_tokenizer_init is True, the code uses .get() with a default of None for routed_experts, but there's no validation to ensure the data is in the expected format if it exists. When the data exists, it should likely have the same shape requirements as in the else branch. Consider adding validation or reshaping logic for this path as well to ensure consistency.

Suggested change
Before:

    routed_experts = output.get("meta_info", {}).get("routed_experts", None)
else:
    from sglang.srt.layers.moe.routed_experts_capturer import extract_routed_experts_from_meta_info
    routed_experts = extract_routed_experts_from_meta_info(output).reshape(
        -1, self.model_config.hf_config.num_hidden_layers, self.model_config.hf_config.num_experts_per_tok

After:

    meta_info = output.get("meta_info", {}) or {}
    routed_experts_raw = meta_info.get("routed_experts", None)
    if routed_experts_raw is not None:
        try:
            routed_experts_tensor = torch.as_tensor(routed_experts_raw)
            routed_experts = routed_experts_tensor.reshape(
                -1,
                self.model_config.hf_config.num_hidden_layers,
                self.model_config.hf_config.num_experts_per_tok,
            )
        except (RuntimeError, ValueError) as e:
            logger.warning(
                "Failed to reshape routed_experts from meta_info with skip_tokenizer_init=True: %s",
                e,
            )
            routed_experts = None
else:
    from sglang.srt.layers.moe.routed_experts_capturer import extract_routed_experts_from_meta_info
    routed_experts = extract_routed_experts_from_meta_info(output).reshape(
        -1,
        self.model_config.hf_config.num_hidden_layers,
        self.model_config.hf_config.num_experts_per_tok,

@wuxibin89 (Collaborator)

@moehanabi Thanks for your PR, could you also provide a training script and report some experiment results?

@moehanabi (Contributor, Author)

@moehanabi Thanks for your PR, could you also provide a training script and report some experiment results?

Hi, I have updated an example and the README doc.

I wonder what kind of experiment results you want; sgl-project/sglang@bed301a did not supply experiment results either, and it's a functional change rather than a performance one.

@wuxibin89 (Collaborator) commented Jan 9, 2026

@moehanabi Thanks for your PR, could you also provide a training script and report some experiment results?

Hi, I have updated an example and the README doc.

I wonder what kind of experiment results you want; sgl-project/sglang@bed301a did not supply experiment results either, and it's a functional change rather than a performance one.

Oh, I mean experiment results: train/val reward metrics w/ and w/o router replay, not efficiency performance (MFU).

@wuxibin89 wuxibin89 merged commit e1cd47b into volcengine:main Jan 9, 2026
67 of 75 checks passed
@wuxibin89 wuxibin89 mentioned this pull request Jan 12, 2026
