Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/disagg/run_disagg_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
)

from specforge.distributed import destroy_distributed, init_distributed
from specforge.launch import build_disagg_offline_runtime, build_offline_runtime
from specforge.optimizer import BF16Optimizer
from specforge.runtime.data_plane.disagg_ingest import (
ingest_offline_features,
Expand All @@ -70,7 +71,6 @@
from specforge.runtime.data_plane.disaggregated import AuthPolicy, SharedDirFeatureStore
from specforge.runtime.data_plane.feature_store import FeatureStore
from specforge.runtime.data_plane.mooncake_store import MooncakeFeatureStore
from specforge.runtime.launch import build_disagg_offline_runtime, build_offline_runtime

RUN_ID = "eagle3-disagg"

Expand Down
4 changes: 2 additions & 2 deletions scripts/train_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
from specforge.core.dflash import OnlineDFlashModel
from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
from specforge.distributed import destroy_distributed, get_dp_group, init_distributed
from specforge.modeling.draft.dflash import DFlashDraftModel
from specforge.modeling.target.dflash_target_model import (
from specforge.inference.target_engine.dflash_target_model import (
DFlashTargetModel,
get_dflash_target_model,
)
from specforge.modeling.draft.dflash import DFlashDraftModel
from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead
from specforge.optimizer import BF16Optimizer
from specforge.tracker import create_tracker
Expand Down
6 changes: 3 additions & 3 deletions scripts/train_domino.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
from specforge.core.domino import OnlineDominoModel
from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
from specforge.distributed import destroy_distributed, get_dp_group, init_distributed
from specforge.modeling.draft.dflash import DFlashDraftModel
from specforge.modeling.target.dflash_target_model import (
from specforge.inference.target_engine.dflash_target_model import (
DFlashTargetModel,
get_dflash_target_model,
)
from specforge.modeling.draft.dflash import DFlashDraftModel
from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead
from specforge.optimizer import BF16Optimizer
from specforge.tracker import create_tracker
Expand Down Expand Up @@ -431,7 +431,7 @@ def get_lambda_base(
) -> float:
# Delegates to the runtime's single source of the Domino lambda schedule so the
# standalone script and DominoTrainStrategy cannot drift.
from specforge.runtime.training.strategy import linear_lambda_base
from specforge.training.strategies.base import linear_lambda_base

return linear_lambda_base(global_step, total_steps, lambda_start, decay_ratio)

Expand Down
4 changes: 2 additions & 2 deletions scripts/train_eagle3_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def optimizer_factory(draft_module):
if online:
# `strategy=` selects the draft model (here eagle3); the runtime resolves
# its StrategySpec. The topology is the builder; the model is a parameter.
from specforge.runtime.launch import build_online_runtime
from specforge.launch import build_online_runtime

# Online target produces features in-loop (any backend exposing
# generate_eagle3_data — HF or SGLang). is_online=True returns the model.
Expand Down Expand Up @@ -171,7 +171,7 @@ def optimizer_factory(draft_module):
print(f"[online] rollout produced {produced} samples", flush=True)
trainer.fit(loader)
else:
from specforge.runtime.launch import build_offline_runtime
from specforge.launch import build_offline_runtime

target_head, _ = build_target_model(args, draft_config, is_online=False)
trainer, loader = build_offline_runtime(
Expand Down
2 changes: 1 addition & 1 deletion specforge/eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch.distributed as dist

from specforge.runtime.contracts import TrainBatch
from specforge.runtime.training.strategy import StepOutput
from specforge.training.strategies.base import StepOutput


class Evaluator:
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
"""Inference / rollout plane: SGLang adapter, capture config, rollout worker.
"""Inference / rollout plane: rollout worker, capture config, adapters, target engines.

Submodules import the SpecForge model / SGLang code, so they are imported
explicitly by callers rather than at package load.
Expand Down
6 changes: 6 additions & 0 deletions specforge/inference/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# coding=utf-8
"""FeatureSource adapters: per-strategy capture over a TargetEngine.

``eagle3.SGLangAdapter`` (default) and ``dflash.DFlashAdapter`` implement the
``rollout_worker.FeatureSource`` protocol.
"""
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@

import torch

from specforge.inference.capture import CaptureConfig
from specforge.runtime.contracts import PromptTask
from specforge.runtime.inference.capture import CaptureConfig


def _as_2d_long(values, device) -> torch.Tensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

import torch

from specforge.inference.capture import CaptureConfig
from specforge.runtime.contracts import PromptTask
from specforge.runtime.inference.capture import CaptureConfig


def _as_2d_long(values, device) -> torch.Tensor:
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@

from typing import Any, Dict, List, Optional, Protocol

from specforge.runtime.contracts import PromptTask, SampleRef
from specforge.runtime.inference.capture import (
from specforge.inference.capture import (
CaptureConfig,
CaptureMismatchError,
verify_capture,
)
from specforge.runtime.contracts import PromptTask, SampleRef

# health states: a worker REPORTS health; the controller decides scheduling.
HEALTH_STATES = ("starting", "ready", "paused", "draining", "unhealthy", "stopped")
Expand Down
40 changes: 40 additions & 0 deletions specforge/inference/target_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# coding=utf-8
"""Target engines: the backend-agnostic capture surface (TargetEngine + factory)."""

from .base import KNOWN_BACKENDS, TargetEngine
from .eagle3_target_model import (
CustomEagle3TargetEngine,
CustomEagle3TargetModel,
Eagle3TargetEngine,
Eagle3TargetModel,
HFEagle3TargetEngine,
HFEagle3TargetModel,
SGLangEagle3TargetEngine,
SGLangEagle3TargetModel,
SGLangServerEagle3TargetEngine,
get_eagle3_target_model,
)
from .factory import available_target_engines, get_target_engine

__all__ = [
"TargetEngine",
"KNOWN_BACKENDS",
"get_target_engine",
"available_target_engines",
"Eagle3TargetEngine",
"SGLangEagle3TargetEngine",
"HFEagle3TargetEngine",
"CustomEagle3TargetEngine",
"SGLangServerEagle3TargetEngine",
"get_eagle3_target_model",
# Back-compat aliases (pre-Phase-B names)
"Eagle3TargetModel",
"SGLangEagle3TargetModel",
"HFEagle3TargetModel",
"CustomEagle3TargetModel",
]

# NOTE: the DFlash engines (dflash_target_model) are intentionally NOT eagerly
# imported here — that module imports sglang internals unconditionally, and this
# package must stay importable without the pinned sglang. Import them from the
# submodule, or via get_target_engine(strategy="dflash", ...).
16 changes: 8 additions & 8 deletions specforge/runtime/launch.py → specforge/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

The named builders span the topology axis (offline vs online, colocated vs
disaggregated); the draft model is the ``strategy=`` parameter, resolved to a
:class:`StrategySpec` (``specforge.runtime.training.registry``) — adding a model
:class:`StrategySpec` (``specforge.training.strategies.registry``) — adding a model
is a registry entry, not a new ``build_*`` family.
"""

Expand All @@ -29,7 +29,7 @@
SQLiteMetadataStore,
)
from specforge.runtime.data_plane import FeatureStore, LocalFeatureStore
from specforge.runtime.training.registry import StrategySpec, resolve_strategy
from specforge.training.strategies.registry import StrategySpec, resolve_strategy

# ---------------------------------------------------------------------------
# Shared assemblers — strategy- and topology-agnostic.
Expand Down Expand Up @@ -109,7 +109,7 @@ def _offline_io(spec: StrategySpec, max_len: int):
f"offline data path for strategy {spec.name!r} is not wired yet: its "
f"StrategySpec needs make_offline_transform + make_offline_collate "
f"(DFlash/Domino use their own feature schema). See "
f"specforge.runtime.training.registry."
f"specforge.training.strategies.registry."
)
return spec.make_offline_collate(), spec.make_offline_transform(max_len)

Expand All @@ -122,7 +122,7 @@ def _online_collate(spec: StrategySpec, collate_fn):
raise NotImplementedError(
f"online data path for strategy {spec.name!r} is not wired yet: its "
f"StrategySpec needs make_online_collate (or pass an explicit collate_fn). "
f"See specforge.runtime.training.registry."
f"See specforge.training.strategies.registry."
)
return spec.make_online_collate()

Expand Down Expand Up @@ -190,8 +190,8 @@ def _assemble_rollout_workers(
f"a {spec.name} capture adapter. Set make_adapter + supports_online=True "
f"on its StrategySpec."
)
from specforge.runtime.inference.capture import CaptureConfig
from specforge.runtime.inference.rollout_worker import RolloutWorker
from specforge.inference.capture import CaptureConfig
from specforge.inference.rollout_worker import RolloutWorker

if aux_hidden_state_layer_ids is None:
aux_hidden_state_layer_ids = tuple(
Expand All @@ -200,7 +200,7 @@ def _assemble_rollout_workers(
if spec.make_adapter is not None:
adapter = spec.make_adapter(target_model, device=device, t2d=t2d)
else:
from specforge.runtime.inference.sglang_adapter import SGLangAdapter
from specforge.inference.adapters.eagle3 import SGLangAdapter

adapter = SGLangAdapter(target_model, device=device, t2d=t2d)
capture = CaptureConfig.from_strategy(
Expand Down Expand Up @@ -272,7 +272,7 @@ def build_offline_runtime(
raise NotImplementedError(
f"offline data path for strategy {spec.name!r} is not wired yet: its "
f"StrategySpec needs make_offline_reader. See "
f"specforge.runtime.training.registry."
f"specforge.training.strategies.registry."
)
collate_fn, per_sample_transform = _offline_io(spec, max_len)
controller, durable_ack = build_control_plane_for_mode(
Expand Down
20 changes: 12 additions & 8 deletions specforge/modeling/target/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .base import KNOWN_BACKENDS, TargetEngine
from .eagle3_target_model import (
from specforge.inference.target_engine.base import KNOWN_BACKENDS, TargetEngine
from specforge.inference.target_engine.eagle3_target_model import (
CustomEagle3TargetEngine,
CustomEagle3TargetModel,
Eagle3TargetEngine,
Expand All @@ -11,7 +11,11 @@
SGLangServerEagle3TargetEngine,
get_eagle3_target_model,
)
from .factory import available_target_engines, get_target_engine
from specforge.inference.target_engine.factory import (
available_target_engines,
get_target_engine,
)

from .target_head import TargetHead

__all__ = [
Expand All @@ -35,8 +39,8 @@
"TargetHead",
]

# NOTE: the DFlash engines (dflash_target_model) are intentionally NOT eagerly
# imported here — that module imports sglang internals unconditionally, and this
# package must stay importable without the pinned sglang (see factory._resolve_loader
# and eagle3_target_model's module docstring). Import them from the submodule, or
# via get_target_engine(strategy="dflash", ...).
# NOTE: the DFlash engines are intentionally NOT eagerly imported here — that
# module imports sglang internals unconditionally, and this package must stay
# importable without the pinned sglang. Import them from
# specforge.inference.target_engine.dflash_target_model, or via
# get_target_engine(strategy="dflash", ...).
7 changes: 0 additions & 7 deletions specforge/runtime/training/__init__.py

This file was deleted.

File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import torch

from specforge.runtime.contracts import TrainBatch
from specforge.runtime.training.backend import TrainingBackend
from specforge.runtime.training.strategy import (
from specforge.training.backend import TrainingBackend
from specforge.training.strategies.base import (
DraftTrainStrategy,
StepContext,
StepOutput,
Expand Down
6 changes: 6 additions & 0 deletions specforge/training/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# coding=utf-8
"""Per-algorithm training strategies + the StrategySpec registry.

Submodules import the SpecForge model code, so they are imported explicitly by
callers rather than at package load.
"""
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

import torch

from specforge.runtime.training.strategy import DraftTrainStrategy, Eagle3TrainStrategy
from specforge.training.strategies.base import DraftTrainStrategy, Eagle3TrainStrategy


def concat_collate(feats: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
Expand Down Expand Up @@ -162,7 +162,7 @@ def _eagle3_offline_collate():
# and emits the same schema. The DFlashTrainStrategy already drops into the
# unchanged TrainerCore/Backend/Loader.

from specforge.runtime.training.strategy import DFlashTrainStrategy
from specforge.training.strategies.base import DFlashTrainStrategy


def _dflash_offline_reader(hidden_states_path, *, run_id, ttt_length, max_len):
Expand Down Expand Up @@ -244,7 +244,7 @@ def pad3d(t): # [1, n, W] -> [1, maxlen, W]


def _dflash_adapter(target_model, *, device="cuda", t2d=None):
from specforge.runtime.inference.dflash_adapter import DFlashAdapter
from specforge.inference.adapters.dflash import DFlashAdapter

return DFlashAdapter(target_model, device=device, t2d=t2d)

Expand Down Expand Up @@ -273,7 +273,7 @@ def _dflash_adapter(target_model, *, device="cuda", t2d=None):
# StepContext (forward_loss(batch, ctx)). That is the whole reason a new algorithm
# needs anything beyond a spec entry here.

from specforge.runtime.training.strategy import DominoTrainStrategy
from specforge.training.strategies.base import DominoTrainStrategy


def _domino_offline_reader(hidden_states_path, *, run_id, ttt_length, max_len):
Expand Down
6 changes: 3 additions & 3 deletions specforge/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from typing import Optional

from specforge.runtime.data_plane import FeatureDataLoader, FeatureStore
from specforge.runtime.training.backend import FSDPTrainingBackend, ParallelConfig
from specforge.runtime.training.registry import StrategySpec, resolve_strategy
from specforge.runtime.training.trainer import TrainerController, TrainerCore
from specforge.training.backend import FSDPTrainingBackend, ParallelConfig
from specforge.training.checkpoint import CheckpointManager
from specforge.training.controller import TrainerController, TrainerCore
from specforge.training.strategies.registry import StrategySpec, resolve_strategy


class Trainer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from accelerate.utils import set_seed

from specforge.distributed import init_distributed
from specforge.modeling.target.eagle3_target_model import SGLangEagle3TargetModel
from specforge.inference.target_engine.eagle3_target_model import (
SGLangEagle3TargetModel,
)
from tests.utils import get_available_port


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from accelerate.utils import set_seed

from specforge.distributed import init_distributed
from specforge.modeling.target.eagle3_target_model import (
from specforge.inference.target_engine.eagle3_target_model import (
CustomEagle3TargetModel,
HFEagle3TargetModel,
SGLangEagle3TargetModel,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_runtime/test_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from specforge.runtime.inference.capture import (
from specforge.inference.capture import (
CaptureConfig,
CaptureMismatchError,
verify_capture,
Expand Down
Loading
Loading