Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions specforge/modeling/target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
HFEagle3TargetModel,
SGLangEagle3TargetEngine,
SGLangEagle3TargetModel,
SGLangServerEagle3TargetEngine,
get_eagle3_target_model,
)
from .factory import available_target_engines, get_target_engine
Expand All @@ -24,6 +25,7 @@
"SGLangEagle3TargetEngine",
"HFEagle3TargetEngine",
"CustomEagle3TargetEngine",
"SGLangServerEagle3TargetEngine",
"get_eagle3_target_model",
# Back-compat aliases (pre-Phase-B names)
"Eagle3TargetModel",
Expand Down
166 changes: 31 additions & 135 deletions specforge/modeling/target/dflash_target_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,17 @@
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather
from transformers import AutoModelForCausalLM

from specforge.distributed import get_tp_group

from .base import TargetEngine
from .sglang_backend import SGLangRunner

# NOTE (Phase B2): this module no longer imports sglang internals at module
# level. The SGLang-version-pinned capture path (ServerArgs / ModelConfig /
# SGLangRunner + the extend/capture forward) lives entirely in
# ``sglang_backend.SGLangCaptureBackend`` and is shared with the eagle3 engine
# (one copy of the forward + mlp-sync). The SGLang engine below composes that
# backend and imports it lazily inside ``from_pretrained``, so
# ``import specforge`` stays sglang-agnostic.


@dataclass
Expand Down Expand Up @@ -92,9 +85,14 @@ class SGLangDFlashTargetEngine(DFlashTargetEngine):

backend = "sglang"

def __init__(self, model_runner: SGLangRunner):
super().__init__()
self.model_runner = model_runner
def __init__(self, backend): # backend: sglang_backend.SGLangCaptureBackend
super().__init__() # capture_layer_ids = None
self._backend = backend

@property
def model_runner(self):
"""Kept for back-compat: the underlying sglang ModelRunner."""
return self._backend.model_runner

@classmethod
def from_pretrained(
Expand All @@ -106,106 +104,24 @@ def from_pretrained(
trust_remote_code: bool = False,
**kwargs,
) -> "SGLangDFlashTargetEngine":
tp_size = dist.get_world_size(get_tp_group())
server_args = ServerArgs(
model_path=pretrained_model_name_or_path,
# Lazy import so `import specforge` still works without the pinned sglang:
# the sglang-version coupling lives entirely in SGLangCaptureBackend, which
# also unifies the extend/mlp-sync forward this engine used to duplicate.
from .sglang_backend import SGLangCaptureBackend

backend = SGLangCaptureBackend.build(
pretrained_model_name_or_path,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
dtype=torch_dtype,
enable_return_hidden_states=True, # Critical for DFlash
disable_cuda_graph=True,
tp_size=tp_size,
pp_size=1,
wrap_eagle3_logits=False,
**kwargs,
)

tp_rank = dist.get_rank(get_tp_group())
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
model_config = ModelConfig.from_server_args(server_args)

model_runner = SGLangRunner(
model_config=model_config,
mem_fraction_static=server_args.mem_fraction_static,
gpu_id=torch.cuda.current_device(),
tp_rank=dist.get_rank(get_tp_group()),
tp_size=server_args.tp_size,
moe_ep_rank=moe_ep_rank,
moe_ep_size=server_args.ep_size,
pp_rank=0,
pp_size=1,
server_args=server_args,
nccl_port=None,
)
return cls(model_runner)
return cls(backend)

def set_capture_layers(self, layer_ids: List[int]) -> None:
super().set_capture_layers(layer_ids)
if hasattr(self.model_runner.model, "set_eagle3_layers_to_capture"):
self.model_runner.model.set_eagle3_layers_to_capture(layer_ids)
print(self.model_runner.model.model.layers_to_capture)

@torch.no_grad
def _extend(self, reqs):
cache_params = CacheInitParams(
disable=False,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
page_size=self.model_runner.server_args.page_size,
)
tree_cache = RadixCache(cache_params)

batch = ScheduleBatch.init_new(
reqs=reqs,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
tree_cache=tree_cache,
model_config=self.model_runner.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,
)
batch.prepare_for_extend()

if require_mlp_sync(self.model_runner.server_args):
Scheduler.prepare_mlp_sync_batch_raw(
batch,
dp_size=self.model_runner.server_args.dp_size,
attn_tp_size=1,
tp_group=self.model_runner.tp_group,
get_idle_batch=None,
disable_cuda_graph=self.model_runner.server_args.disable_cuda_graph,
spec_algorithm=SpeculativeAlgorithm.NONE,
speculative_num_draft_tokens=None,
require_mlp_tp_gather=require_mlp_tp_gather(
self.model_runner.server_args
),
disable_overlap_schedule=self.model_runner.server_args.disable_overlap_schedule,
offload_tags=set(),
)

model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch.capture_hidden_mode = CaptureHiddenMode.FULL

output = self.model_runner.forward(forward_batch)
if hasattr(output, "logits_output"):
output = output.logits_output

input_lens = [len(req.origin_input_ids) for req in reqs]
if (
hasattr(output, "aux_hidden_states")
and output.aux_hidden_states is not None
):
hidden_states_list = torch.split(
output.aux_hidden_states, input_lens, dim=0
)
elif hasattr(output, "hidden_states") and output.hidden_states is not None:
hidden_states_list = torch.split(output.hidden_states, input_lens, dim=0)
else:
raise ValueError("SGLang output does not contain hidden states.")

self.model_runner.req_to_token_pool.clear()
self.model_runner.token_to_kv_pool_allocator.clear()

return hidden_states_list
super().set_capture_layers(layer_ids) # records self.capture_layer_ids
# Some target models expose set_eagle3_layers_to_capture; guard on it.
self._backend.set_eagle3_capture_layers(layer_ids, if_supported=True)

@torch.no_grad()
def generate_dflash_data(
Expand All @@ -214,29 +130,9 @@ def generate_dflash_data(
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
) -> DFlashTargetOutput:
sampling_params = SamplingParams(temperature=0, max_new_tokens=1)
reqs, data_cache = [], []

if isinstance(input_ids, torch.Tensor):
input_ids_list = torch.split(input_ids, 1, dim=0)
attn_mask_list = torch.split(attention_mask, 1, dim=0)
loss_mask_list = torch.split(loss_mask, 1, dim=0)

for idx, (curr_ids, curr_attn, curr_loss) in enumerate(
zip(input_ids_list, attn_mask_list, loss_mask_list)
):
req = Req(
rid=str(idx),
origin_input_text="",
origin_input_ids=curr_ids.view(-1).tolist(),
sampling_params=sampling_params,
)
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
data_cache.append((curr_ids, curr_attn, curr_loss))
reqs.append(req)

hidden_states_list = self._extend(reqs)
data_cache, hidden_states_list = self._backend.extend_dflash(
input_ids, attention_mask, loss_mask
)

# Stack back to batch
hidden_states = torch.cat([h.unsqueeze(0) for h in hidden_states_list], dim=0)
Expand Down
Loading
Loading