Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions specforge/runtime/ARCHITECTURE.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ its own design note (see "Per-plane internals" below).

## End-to-end flow

**Online:** `RolloutWorker.run_once` leases prompts (`lease_prompt_tasks`), calls `generate_features` (which drives `generate_eagle3_data`), runs `verify_capture`, then `FeatureStore.put` writes tensors **directly to the data plane**. Only the resulting `SampleRef` metadata goes to the controller via `commit_samples`, which dedups through `MetadataStore.commit_sample` and enqueues fresh refs onto `SampleRefQueue`.
**Online:** `RolloutWorker.run_once` leases prompts (`lease_prompt_tasks`), calls `generate_features` (which drives the engine's generic `TargetEngine.capture`), runs `verify_capture`, then `FeatureStore.put` writes tensors **directly to the data plane**. Only the resulting `SampleRef` metadata goes to the controller via `commit_samples`, which dedups through `MetadataStore.commit_sample` and enqueues fresh refs onto `SampleRefQueue`.

**Offline:** `OfflineManifestReader.read()` emits in-place `file://` `SampleRef`s (no tensor copy) and the launcher calls `enqueue_offline_refs`, which dedups and enqueues onto the **same** `SampleRefQueue`.

Expand All @@ -36,7 +36,7 @@ flowchart TD
subgraph COMPUTE[compute autonomous loops]
RW[RolloutWorker run_once loop]
SG[SGLangAdapter generate_features]
TGT[Eagle3TargetModel generate_eagle3_data]
TGT[TargetEngine capture via SGLangCaptureBackend]
TR[TrainerController fit loop]
CORE[TrainerCore train_step]
STRAT[Eagle3TrainStrategy forward_loss]
Expand All @@ -59,7 +59,7 @@ flowchart TD
RW -->|register_rollout_worker| CTRL
RW -->|lease_prompt_tasks| CTRL
RW -->|generate_features| SG
SG -->|generate_eagle3_data| TGT
SG -->|capture| TGT
RW -.->|put| STORE
RW -->|commit_samples| CTRL
RW -->|fail_prompt_tasks| CTRL
Expand Down Expand Up @@ -97,7 +97,7 @@ flowchart TD
| RolloutWorker | register_rollout_worker | control | Register worker, obtain authoritative worker_id (no-tensor guard on info) |
| RolloutWorker | lease_prompt_tasks | control | Pop up to max_tasks pending PromptTasks, mark leased to worker_id |
| RolloutWorker | generate_features | compute | Ask the FeatureSource (SGLangAdapter) for one feature dict per task |
| SGLangAdapter | generate_eagle3_data | compute | Run the target engine's batched forward to extract hidden_states/target |
| SGLangAdapter | TargetEngine.capture | compute | Run the target engine's batched forward (sglang glue in SGLangCaptureBackend) to extract hidden_states/target |
| RolloutWorker | put | data | Persist verified feature tensors directly to FeatureStore, get back a SampleRef |
| RolloutWorker | abort | data | Clean up a partial/failed write so no corrupt sample is left |
| RolloutWorker | commit_samples | control | Commit metadata-only SampleRefs; dedup + enqueue fresh refs |
Expand Down
8 changes: 4 additions & 4 deletions specforge/runtime/inference/DESIGN.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ plane exchanges are in [`../contracts.py`](../contracts.py).

## Responsibility

The rollout/inference plane turns leased PromptTasks into per-sample feature tensors and commits only their typed SampleRef metadata to the controller — it never hands a tensor to the controller. It owns the clean boundary to the target engine (Eagle3TargetModel via generate_eagle3_data), the only place target→draft projection/pruning happens, and the loud pre-write validation (verify_capture against a typed CaptureConfig) that converts layer/name/width/target-dim mismatches into immediate, localized errors at the extraction boundary instead of downstream trainer bugs.
The rollout/inference plane turns leased PromptTasks into per-sample feature tensors and commits only their typed SampleRef metadata to the controller — it never hands a tensor to the controller. It owns the clean boundary to the target engine (a `TargetEngine` via its generic `capture()`; the sglang-version glue lives behind `SGLangCaptureBackend`), the only place target→draft projection/pruning happens, and the loud pre-write validation (verify_capture against a typed CaptureConfig) that converts layer/name/width/target-dim mismatches into immediate, localized errors at the extraction boundary instead of downstream trainer bugs.

## Internal mechanics

Expand All @@ -19,7 +19,7 @@ flowchart TD

A[lease_prompt_tasks] --> B[generate_features per batch]
B --> C[SGLangAdapter group by len single forward]
C --> D[generate_eagle3_data]
C --> D[TargetEngine.capture]
D --> E[_project_target logits or pruned_logits]
E --> F[verify_capture vs CaptureConfig]
F -->|mismatch| G[fail_prompt_tasks non retryable]
Expand All @@ -34,7 +34,7 @@ flowchart TD
class G control;
```

The rollout plane turns leased `PromptTask`s into per-sample feature tensors and commits only their typed `SampleRef` metadata — it never hands a tensor to the controller. `RolloutWorker.run_once` is the core loop: lease up to `max_tasks`, call `feature_source.generate_features(tasks, capture=...)` once for the whole batch, enforce a strict `len(feats)==len(tasks)` contract, then per sample pop the out-of-band `__aux_layer_ids__`, run `verify_capture`, and on success `FeatureStore.put` (tensors go straight to the data plane). Every leased task ends in exactly one terminal controller action — `commit_samples` on success or `fail_prompt_tasks` on generate failure / wrong count / capture mismatch / put failure — with `sample_id = f"{run_id}:{task.task_id}"` deterministic and a put exception triggering `abort`. `CaptureConfig` is a frozen, strategy-derived contract carrying `feature_names`, `aux_hidden_state_layer_ids`, `target_repr`, and the derived `expected_aux_width` / `expected_target_dim()`. `verify_capture` is the loud pre-`put` validator: it checks name presence, aux-layer-id equality, aux last-dim width, and target last-dim, gating `pruned_logits` on a non-None `vocab_map_version`, raising `CaptureMismatchError` at the boundary. `SGLangAdapter` is the only place target to draft projection happens (`_project_target`: passthrough for `logits`, `t2d`-indexing for `pruned_logits`), batching equal-length tasks into one padding-free `generate_eagle3_data` forward and slicing rows back into original task order.
The rollout plane turns leased `PromptTask`s into per-sample feature tensors and commits only their typed `SampleRef` metadata — it never hands a tensor to the controller. `RolloutWorker.run_once` is the core loop: lease up to `max_tasks`, call `feature_source.generate_features(tasks, capture=...)` once for the whole batch, enforce a strict `len(feats)==len(tasks)` contract, then per sample pop the out-of-band `__aux_layer_ids__`, run `verify_capture`, and on success `FeatureStore.put` (tensors go straight to the data plane). Every leased task ends in exactly one terminal controller action — `commit_samples` on success or `fail_prompt_tasks` on generate failure / wrong count / capture mismatch / put failure — with `sample_id = f"{run_id}:{task.task_id}"` deterministic and a put exception triggering `abort`. `CaptureConfig` is a frozen, strategy-derived contract carrying `feature_names`, `aux_hidden_state_layer_ids`, `target_repr`, and the derived `expected_aux_width` / `expected_target_dim()`. `verify_capture` is the loud pre-`put` validator: it checks name presence, aux-layer-id equality, aux last-dim width, and target last-dim, gating `pruned_logits` on a non-None `vocab_map_version`, raising `CaptureMismatchError` at the boundary. `SGLangAdapter` is the only place target to draft projection happens (`_project_target`: passthrough for `logits`, `t2d`-indexing for `pruned_logits`), batching equal-length tasks into one padding-free `TargetEngine.capture` forward and slicing rows back into original task order.

## Endpoints

Expand All @@ -45,7 +45,7 @@ The rollout plane turns leased `PromptTask`s into per-sample feature tensors and
| `RolloutWorker` | `DataFlowController.register_rollout_worker` | control |
| `RolloutWorker` | `DataFlowController.lease_prompt_tasks` | control |
| `RolloutWorker` | `SGLangAdapter.generate_features` | compute |
| `SGLangAdapter` | `Eagle3TargetModel.generate_eagle3_data` | compute |
| `SGLangAdapter` | `TargetEngine.capture` (→ `SGLangCaptureBackend`) | compute |
| `RolloutWorker` | `FeatureStore.put` | data |
| `RolloutWorker` | `FeatureStore.abort` | data |
| `RolloutWorker` | `DataFlowController.commit_samples` | control |
Expand Down
13 changes: 7 additions & 6 deletions specforge/runtime/inference/dflash_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
# http://www.apache.org/licenses/LICENSE-2.0
"""DFlashAdapter: the DFlash counterpart of SGLangAdapter.

Wraps a ``DFlashTargetModel`` (sglang / hf, both expose ``generate_dflash_data``)
and returns per-sample feature dicts for the DataFlow rollout. DFlash's schema is
Wraps a DFlash ``TargetEngine`` (sglang / hf), calling its generic ``capture(...)``
(the legacy ``generate_dflash_data`` is kept as a back-compat alias), and returns
per-sample feature dicts for the DataFlow rollout. DFlash's schema is
``{input_ids, hidden_states, loss_mask}`` — note ``hidden_states`` is the
concatenated target capture layers, and there is NO ``target`` distribution / no
vocab projection (DFlash trains on hard real-token labels), so unlike
Expand Down Expand Up @@ -43,7 +44,7 @@ def _as_2d_long(values, device) -> torch.Tensor:


class DFlashAdapter:
"""Adapter over a SpecForge ``DFlashTargetModel`` (any ``generate_dflash_data``)."""
"""Adapter over a SpecForge DFlash ``TargetEngine`` (via its generic ``capture()``)."""

SUPPORTED_FEATURE_NAMES = {"input_ids", "hidden_states", "loss_mask"}

Expand All @@ -63,8 +64,8 @@ def generate_features(
) -> List[Dict[str, Any]]:
"""Extract per-sample DFlash features, batching equal-length prompts.

Mirrors SGLangAdapter's length-grouped batching, but calls
``generate_dflash_data`` and emits the DFlash schema. The target must have
Mirrors SGLangAdapter's length-grouped batching, but calls the engine's
generic ``capture(...)`` and emits the DFlash schema. The target must have
had ``set_capture_layers`` called so ``hidden_states`` width matches the
draft's ``len(target_layer_ids) * hidden_size``.
"""
Expand Down Expand Up @@ -95,7 +96,7 @@ def generate_features(
dim=0,
)
attention_mask = torch.ones_like(input_ids)
data = self.target_model.generate_dflash_data(
data = self.target_model.capture(
input_ids=input_ids,
attention_mask=attention_mask,
loss_mask=loss_mask,
Expand Down
15 changes: 8 additions & 7 deletions specforge/runtime/inference/sglang_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@

``generate_features(tasks, *, capture)`` is the single extraction entry point.
``capture`` is the typed :class:`CaptureConfig` derived from the active strategy,
not an untyped dict. The adapter wraps the existing ``Eagle3TargetModel`` (sglang
/ hf / custom backends all expose ``generate_eagle3_data``), records the exact
aux-layer IDs it captured, applies the target→draft projection demanded by
not an untyped dict. The adapter wraps an EAGLE3 ``TargetEngine`` (sglang / hf /
custom backends), calling its generic ``capture(...)`` (the de-EAGLE3'd boundary;
the legacy ``generate_eagle3_data`` is kept as a back-compat alias), records the
exact aux-layer IDs it captured, applies the target→draft projection demanded by
``capture.target_repr`` (the only place pruning happens), and returns per-sample
feature dicts. The RolloutWorker then runs :func:`verify_capture` before any
store write, so a layer/name/width mismatch fails loudly at this boundary rather
Expand All @@ -39,7 +40,7 @@ def _as_2d_long(values, device) -> torch.Tensor:


class SGLangAdapter:
"""Adapter over a SpecForge ``Eagle3TargetModel`` (or any ``generate_eagle3_data``)."""
"""Adapter over a SpecForge EAGLE3 ``TargetEngine`` (via its generic ``capture()``)."""

SUPPORTED_FEATURE_NAMES = {
"input_ids",
Expand Down Expand Up @@ -94,8 +95,8 @@ def generate_features(
"""Extract per-sample features, batching the engine call.

Tasks are grouped by sequence length and each group is run through
``generate_eagle3_data`` in ONE batched forward (the engine's native
batching), instead of a per-sample loop that would serialize N forwards.
the engine's generic ``capture(...)`` in ONE batched forward (the engine's
native batching), instead of a per-sample loop that would serialize N forwards.
Equal-length grouping avoids intra-batch padding, so per-sample features
are sliced out cleanly. The result preserves task order.
"""
Expand Down Expand Up @@ -125,7 +126,7 @@ def generate_features(
dim=0,
)
attention_mask = torch.ones_like(input_ids)
data = self.target_model.generate_eagle3_data(
data = self.target_model.capture(
input_ids=input_ids,
attention_mask=attention_mask,
loss_mask=loss_mask,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_runtime/_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def build_hf_target(workdir, hidden=H, layers=8, vocab=V, aux_layer_ids=(1, 3, 4
"""Build a tiny HF Llama target wrapped by the SpecForge HF eagle3 backend."""
from transformers import LlamaConfig, LlamaForCausalLM

from specforge.modeling.target import get_eagle3_target_model
from specforge.modeling.target import get_target_engine

cfg = LlamaConfig(
hidden_size=hidden,
Expand All @@ -170,8 +170,9 @@ def build_hf_target(workdir, hidden=H, layers=8, vocab=V, aux_layer_ids=(1, 3, 4
model = LlamaForCausalLM(cfg)
target_dir = os.path.join(workdir, "hf_target")
model.save_pretrained(target_dir)
target = get_eagle3_target_model(
pretrained_model_name_or_path=target_dir,
target = get_target_engine(
target_dir,
strategy="eagle3",
backend="hf",
torch_dtype=torch.bfloat16,
device="cuda",
Expand Down
9 changes: 4 additions & 5 deletions tests/test_runtime/test_dflash_online_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ def test_online_rollout_then_fsdp_train(self):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from specforge.modeling.target.dflash_target_model import (
get_dflash_target_model,
)
from specforge.modeling.target import get_target_engine
from specforge.optimizer import BF16Optimizer
from specforge.runtime.contracts import assert_no_tensors
from specforge.runtime.launch import build_online_runtime
Expand All @@ -49,8 +47,9 @@ def test_online_rollout_then_fsdp_train(self):
self.assertEqual(width, HIDDEN)

# HF DFlash target (no sglang) capturing the same layers the draft expects
target = get_dflash_target_model(
pretrained_model_name_or_path=target_dir,
target = get_target_engine(
target_dir,
strategy="dflash",
backend="hf",
torch_dtype=torch.bfloat16,
device="cuda",
Expand Down
9 changes: 4 additions & 5 deletions tests/test_runtime/test_domino_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ def test_online_rollout_then_fsdp_train(self):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from specforge.modeling.target.dflash_target_model import (
get_dflash_target_model,
)
from specforge.modeling.target import get_target_engine
from specforge.optimizer import BF16Optimizer
from specforge.runtime.launch import build_online_runtime

Expand All @@ -135,8 +133,9 @@ def test_online_rollout_then_fsdp_train(self):
self.assertEqual(width, HIDDEN)

# domino captures the same hidden_states as DFlash -> reuse the DFlash target
target = get_dflash_target_model(
pretrained_model_name_or_path=target_dir,
target = get_target_engine(
target_dir,
strategy="domino",
backend="hf",
torch_dtype=torch.bfloat16,
device="cuda",
Expand Down
Loading
Loading