Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f057907
[DataFlow runtime · M6] MooncakeFeatureStore zero-copy transport
maocheng23 Jun 28, 2026
02bb95b
fix(mooncake): reject zero-copy get_into short reads; drop dead _tens…
maocheng23 Jun 29, 2026
6338371
[DataFlow runtime · online disagg 1/n] StreamingRefChannel — cross-pr…
maocheng23 Jun 28, 2026
7cd5452
[DataFlow runtime · online disagg 2/n] build_disagg_online_{producer,…
maocheng23 Jun 28, 2026
8adec57
Merge pull request #622 from sgl-project/dataflow-up-17-online-disagg
jiapingW Jun 30, 2026
d6a6fd3
[DataFlow runtime · online] O1.1 — shared cross-process control plane
maocheng23 Jun 29, 2026
17a8770
[DataFlow runtime · online] O1.2 — named builder + interleaved async …
maocheng23 Jun 29, 2026
8faf111
[DataFlow runtime] Composable launch: StrategySpec registry + paramet…
maocheng23 Jun 30, 2026
a6b8b7d
[DataFlow runtime] DFlash end-to-end on the composable launch (offlin…
maocheng23 Jun 30, 2026
9c1c020
[DataFlow runtime] Domino end-to-end + StepContext for schedule-depen…
maocheng23 Jun 30, 2026
3838788
[DataFlow runtime] Phase B1 — TargetEngine ABC + de-EAGLE3 the target…
maocheng23 Jul 1, 2026
29817b7
[DataFlow runtime] Phase B2 — decouple the target engine from the sgl…
maocheng23 Jul 1, 2026
3c768c6
refactor: clarify sglang eagle3 capture entrypoints
maocheng23 Jul 1, 2026
1d9060e
[DataFlow runtime] Phase B3 — domain Trainer wrapping the runtime seam
maocheng23 Jul 1, 2026
e6aaead
[DataFlow runtime] Phase B4 — adopt the de-EAGLE3 surface (cutover + …
maocheng23 Jul 1, 2026
f455372
[DataFlow runtime] Phase C — colocated lightweight control plane
maocheng23 Jul 1, 2026
b8cd033
[Phase C review fixes] deployment_mode on colocated builders, leaner …
maocheng23 Jul 1, 2026
7a26453
[Phase C review fixes 2] pin build_online_runtime to local_colocated
maocheng23 Jul 1, 2026
7ed8527
Merge pull request #624 from sgl-project/dataflow-up-19-online-shared…
jiapingW Jul 2, 2026
15085dc
Merge pull request #625 from sgl-project/dataflow-up-20-online-async-…
jiapingW Jul 2, 2026
ec10f47
Merge pull request #627 from sgl-project/dataflow-up-21-composable-la…
jiapingW Jul 2, 2026
14a18ba
Merge pull request #628 from sgl-project/dataflow-up-22-dflash
jiapingW Jul 2, 2026
a4dde0b
style: apply pre-commit (black/isort/autoflake)
maocheng23 Jul 2, 2026
e3b5ce6
[Phase C review fixes] rename control-plane builder and trim comments
maocheng23 Jul 2, 2026
9ede82d
Merge pull request #629 from sgl-project/dataflow-up-23-domino
jiapingW Jul 3, 2026
e94aa3f
Merge pull request #631 from sgl-project/dataflow-up-24-target-engine
jiapingW Jul 3, 2026
ac8f878
Merge pull request #632 from sgl-project/dataflow-up-25-sglang-captur…
jiapingW Jul 3, 2026
c0234ed
Merge pull request #633 from sgl-project/dataflow-up-26-domain-trainer
jiapingW Jul 3, 2026
6f0da78
Merge pull request #635 from sgl-project/dataflow-up-27-target-engine…
jiapingW Jul 3, 2026
8fb067c
Merge pull request #636 from sgl-project/dataflow-up-28-colocated-lig…
jiapingW Jul 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions examples/disagg/run_disagg_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
manifest. No model, no trainer.
* **consumer** (training pool): waits for the manifest, builds the EAGLE3 model
exactly as the offline launcher does, and trains via
``build_disagg_eagle3_runtime`` reading features from the shared store.
``build_disagg_offline_runtime`` reading features from the shared store.

The control plane carries only ``SampleRef`` metadata across the boundary; the
feature tensors travel through the shared store. Disaggregation changes *where*
Expand Down Expand Up @@ -70,10 +70,7 @@
from specforge.runtime.data_plane.disaggregated import AuthPolicy, SharedDirFeatureStore
from specforge.runtime.data_plane.feature_store import FeatureStore
from specforge.runtime.data_plane.mooncake_store import MooncakeFeatureStore
from specforge.runtime.launch import (
build_disagg_eagle3_runtime,
build_offline_eagle3_runtime,
)
from specforge.runtime.launch import build_disagg_offline_runtime, build_offline_runtime

RUN_ID = "eagle3-disagg"

Expand Down Expand Up @@ -227,7 +224,8 @@ def run_colocated(args) -> None:
sanity_check(args)
eagle3_model, target_head, optimizer_factory = _build_model_and_optimizer(args)
print(f"[colocated] training from {args.train_hidden_states_path}", flush=True)
trainer, loader = build_offline_eagle3_runtime(
trainer, loader = build_offline_runtime(
strategy="eagle3",
hidden_states_path=args.train_hidden_states_path,
eagle3_model=eagle3_model,
target_head=target_head,
Expand Down Expand Up @@ -277,7 +275,8 @@ def run_consumer(args) -> None:
location = getattr(store, "root", f"mooncake://{store.store_id}")
print(f"[consumer] training from {len(refs)} disagg refs in {location}", flush=True)

trainer, loader = build_disagg_eagle3_runtime(
trainer, loader = build_disagg_offline_runtime(
strategy="eagle3",
feature_store=store,
refs=refs,
eagle3_model=eagle3_model,
Expand Down
10 changes: 4 additions & 6 deletions scripts/train_domino.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,13 +429,11 @@ def get_lambda_base(
lambda_start: float = 1.0,
decay_ratio: float = 0.5,
) -> float:
decay_steps = max(1, int(total_steps * decay_ratio))
progress = min(global_step / decay_steps, 1.0)
lambda_base = lambda_start * (1.0 - progress)
# Delegates to the runtime's single source of the Domino lambda schedule so the
# standalone script and DominoTrainStrategy cannot drift.
from specforge.runtime.training.strategy import linear_lambda_base

# Clamp to [0, 1].
lambda_base = max(0.0, min(1.0, lambda_base))
return lambda_base
return linear_lambda_base(global_step, total_steps, lambda_start, decay_ratio)


def main():
Expand Down
52 changes: 27 additions & 25 deletions scripts/train_eagle3_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def optimizer_factory(draft_module):
logger = lambda m, s: print(f"step {s}: {m}", flush=True)

if online:
from specforge.runtime.launch import build_online_eagle3_runtime
# `strategy=` selects the draft model (here eagle3); the runtime resolves
# its StrategySpec. The topology is the builder; the model is a parameter.
from specforge.runtime.launch import build_online_runtime

# Online target produces features in-loop (any backend exposing
# generate_eagle3_data — HF or SGLang). is_online=True returns the model.
Expand All @@ -143,37 +145,37 @@ def optimizer_factory(draft_module):

# num_epochs=1: the rollout output is a consume-once stream. Multi-epoch
# online (re-rollout each epoch) is a follow-up; one rollout pass here.
trainer, loader, workers, controller, drive_rollout = (
build_online_eagle3_runtime(
target_model=target_model,
prompts=prompts,
eagle3_model=eagle3_model,
optimizer_factory=optimizer_factory,
run_id="eagle3-online",
output_dir=args.output_dir,
target_hidden_size=hidden_size,
target_vocab_size=vocab_size,
target_repr="logits",
ttt_length=args.ttt_length,
batch_size=args.target_batch_size,
accumulation_steps=args.draft_accumulation_steps,
num_epochs=1,
max_steps=args.max_num_steps,
save_interval=args.save_interval,
tp_size=args.tp_size,
sp_ulysses_size=args.sp_ulysses_size,
sp_ring_size=args.sp_ring_size,
logger=logger,
)
trainer, loader, workers, controller, drive_rollout = build_online_runtime(
strategy="eagle3",
target_model=target_model,
prompts=prompts,
eagle3_model=eagle3_model,
optimizer_factory=optimizer_factory,
run_id="eagle3-online",
output_dir=args.output_dir,
target_hidden_size=hidden_size,
target_vocab_size=vocab_size,
target_repr="logits",
ttt_length=args.ttt_length,
batch_size=args.target_batch_size,
accumulation_steps=args.draft_accumulation_steps,
num_epochs=1,
max_steps=args.max_num_steps,
save_interval=args.save_interval,
tp_size=args.tp_size,
sp_ulysses_size=args.sp_ulysses_size,
sp_ring_size=args.sp_ring_size,
logger=logger,
)
produced = drive_rollout()
print(f"[online] rollout produced {produced} samples", flush=True)
trainer.fit(loader)
else:
from specforge.runtime.launch import build_offline_eagle3_runtime
from specforge.runtime.launch import build_offline_runtime

target_head, _ = build_target_model(args, draft_config, is_online=False)
trainer, loader = build_offline_eagle3_runtime(
trainer, loader = build_offline_runtime(
strategy="eagle3",
hidden_states_path=args.train_hidden_states_path,
eagle3_model=eagle3_model,
target_head=target_head,
Expand Down
18 changes: 16 additions & 2 deletions specforge/modeling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
# from .auto import AutoDistributedTargetModel, AutoDraftModelConfig, AutoEagle3DraftModel
from .auto import AutoDraftModelConfig, AutoEagle3DraftModel
from .draft.llama3_eagle import LlamaForCausalLMEagle3
from .target.eagle3_target_model import (
from .target import (
CustomEagle3TargetEngine,
CustomEagle3TargetModel,
Eagle3TargetEngine,
HFEagle3TargetEngine,
HFEagle3TargetModel,
SGLangEagle3TargetEngine,
SGLangEagle3TargetModel,
TargetEngine,
get_eagle3_target_model,
get_target_engine,
)

__all__ = [
"LlamaForCausalLMEagle3",
# Generic (Phase B) surface
"TargetEngine",
"Eagle3TargetEngine",
"SGLangEagle3TargetEngine",
"HFEagle3TargetEngine",
"CustomEagle3TargetEngine",
"get_target_engine",
"get_eagle3_target_model",
# Back-compat aliases (pre-Phase-B names)
"SGLangEagle3TargetModel",
"HFEagle3TargetModel",
"CustomEagle3TargetModel",
"get_eagle3_target_model",
"AutoDraftModelConfig",
"AutoEagle3DraftModel",
]
27 changes: 26 additions & 1 deletion specforge/modeling/target/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,42 @@
from .base import KNOWN_BACKENDS, TargetEngine
from .eagle3_target_model import (
CustomEagle3TargetEngine,
CustomEagle3TargetModel,
Eagle3TargetEngine,
Eagle3TargetModel,
HFEagle3TargetEngine,
HFEagle3TargetModel,
SGLangEagle3TargetEngine,
SGLangEagle3TargetModel,
SGLangServerEagle3TargetEngine,
get_eagle3_target_model,
)
from .factory import available_target_engines, get_target_engine
from .target_head import TargetHead

__all__ = [
# Generic (Phase B) surface
"TargetEngine",
"KNOWN_BACKENDS",
"get_target_engine",
"available_target_engines",
# EAGLE3 engines
"Eagle3TargetEngine",
"SGLangEagle3TargetEngine",
"HFEagle3TargetEngine",
"CustomEagle3TargetEngine",
"SGLangServerEagle3TargetEngine",
"get_eagle3_target_model",
# Back-compat aliases (pre-Phase-B names)
"Eagle3TargetModel",
"SGLangEagle3TargetModel",
"HFEagle3TargetModel",
"CustomEagle3TargetModel",
"get_eagle3_target_model",
"TargetHead",
]

# NOTE: the DFlash engines (dflash_target_model) are intentionally NOT eagerly
# imported here — that module imports sglang internals unconditionally, and this
# package must stay importable without the pinned sglang (see factory._resolve_loader
# and eagle3_target_model's module docstring). Import them from the submodule, or
# via get_target_engine(strategy="dflash", ...).
106 changes: 106 additions & 0 deletions specforge/modeling/target/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# coding=utf-8
# Copyright 2024 The SpecForge team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""TargetEngine: the backend-agnostic target-model abstraction (Phase B).

This is the de-EAGLE3'd boundary extracted from the former ``Eagle3TargetModel``
ABC. A ``TargetEngine`` wraps a *frozen* target model and exposes ONE generic
extraction entry point, :meth:`capture`, plus a real ``backend`` tag. The
inference/transport split (sglang in-process / hf / custom / sglang_server) is a
*backend* axis **under** each algorithm engine, and — crucially — the
sglang-version-specific glue lives behind a replaceable capture backend
(``sglang_backend``), NOT in the algorithm engine, so a sglang bump touches one
place instead of every ``*TargetModel`` subclass.

Two sibling algorithm engines subclass this ABC:

* :class:`Eagle3TargetEngine` (``eagle3_target_model``) — EAGLE3 TTT capture
(aux hidden states + logits), keeps the EAGLE3-specific
``set_aux_hidden_states_layers`` / ``generate_eagle3_data``.
* :class:`DFlashTargetEngine` (``dflash_target_model``) — DFlash block capture
(concatenated layer hidden states, no target distribution).

The runtime inference adapters (``SGLangAdapter`` / ``DFlashAdapter``) wrap a
``TargetEngine`` and remain the ``FeatureSource`` seam to the ``RolloutWorker`` —
they are unchanged by this extraction; they call the generic engine and read the
now-real ``.backend`` tag in ``health()``.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, List, Optional

import torch

# Known transport/backend tags. ``sglang_server`` (a live frozen-target SGLang
# server, cross-node) is introduced by the sglang-capture-backend PR; its capture
# depth is gated by the O1.3 spike. The tag set is advisory (informational, used
# by adapter health + provenance), not an enum the ABC enforces.
KNOWN_BACKENDS = ("sglang", "hf", "custom", "sglang_server")


class TargetEngine(ABC):
"""Backend-agnostic frozen-target engine.

Subclasses are organised on two axes: the *algorithm* (EAGLE3 / DFlash —
the intermediate ABCs :class:`Eagle3TargetEngine` / :class:`DFlashTargetEngine`)
and the *backend/transport* (sglang / hf / custom / sglang_server — the
concrete leaf classes). Only the leaf classes are instantiable; each sets a
real :attr:`backend` tag.
"""

#: Transport tag; concrete leaf engines override this class attribute
#: ("sglang" / "hf" / "custom" / "sglang_server"). Read by the inference
#: adapters' ``health()`` and recorded as rollout provenance.
backend: str = "unknown"

@classmethod
@abstractmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
torch_dtype: Optional[torch.dtype] = None,
device: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs,
) -> "TargetEngine":
"""Load a frozen target model for this backend."""

@abstractmethod
def capture(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
loss_mask: torch.Tensor,
**kwargs,
) -> Any:
"""Run the frozen target forward and extract training features.

The generic extraction entry point that replaces the EAGLE3-named
``generate_eagle3_data``. Returns a per-algorithm output dataclass
(``Eagle3TargetOutput`` / ``DFlashTargetOutput``). Algorithm engines keep
their original ``generate_*_data`` method as the concrete implementation
and as a back-compat alias; ``capture`` simply dispatches to it, so the
extraction is byte-identical to the pre-refactor path.
"""

def set_capture_layers(self, layer_ids: Optional[List[int]] = None) -> None:
"""Select which target layers' hidden states to capture.

The generic hook. EAGLE3 maps this onto its 3 aux layers
(``set_aux_hidden_states_layers``); DFlash captures an arbitrary list.
Engines that do not capture intermediate layers may leave this
unimplemented.
"""
raise NotImplementedError(
f"{type(self).__name__} does not implement set_capture_layers"
)


__all__ = ["TargetEngine", "KNOWN_BACKENDS"]
Loading