12 changes: 6 additions & 6 deletions specforge/inference/DESIGN.md
@@ -7,7 +7,7 @@ plane exchanges are in [`../contracts.py`](../contracts.py).

## Responsibility

-The rollout/inference plane turns leased PromptTasks into per-sample feature tensors and commits only their typed SampleRef metadata to the controller — it never hands a tensor to the controller. It owns the clean boundary to the target engine (a `TargetEngine` via its generic `capture()`; the sglang-version glue lives behind `SGLangCaptureBackend`), the only place target→draft projection/pruning happens, and the loud pre-write validation (verify_capture against a typed CaptureConfig) that converts layer/name/width/target-dim mismatches into immediate, localized errors at the extraction boundary instead of downstream trainer bugs.
+The rollout/inference plane turns leased PromptTasks into per-sample feature tensors and commits only their typed SampleRef metadata to the controller — it never hands a tensor to the controller. It owns the clean boundary to the target engine (a `TargetEngine` via its generic `capture()`; the sglang-version glue lives behind `SGLangCaptureBackend`), the only place target→draft projection/pruning happens, and the loud pre-write validation (`verify_feature_contract` against a typed `FeatureContract`) that converts layer/name/width/target-dim mismatches into immediate, localized errors at the extraction boundary instead of downstream trainer bugs.
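
To make the pre-write validation concrete, here is a minimal sketch of what a contract check of this kind could look like. It is not the PR's implementation: the field layout is simplified (shapes are plain tuples instead of tensors, `expected_target_dim` is a field rather than a method, and the `pruned_logits` / `vocab_map_version` gate is omitted), and everything beyond the names quoted in the design text is an assumption.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Tuple


class FeatureContractError(ValueError):
    """Raised when captured features disagree with the contract."""


@dataclass(frozen=True)
class FeatureContract:
    # Simplified stand-in (assumption): the real contract also carries
    # target_repr and derives its expected sizes from the strategy.
    feature_names: Tuple[str, ...]
    aux_hidden_state_layer_ids: Tuple[int, ...]
    expected_aux_width: int
    expected_target_dim: int


def verify_feature_contract(
    shapes: Dict[str, Tuple[int, ...]],
    contract: FeatureContract,
    recorded_aux_layer_ids: Optional[Sequence[int]] = None,
) -> None:
    """Check feature shapes against the contract before any store write."""
    # 1. Every feature the contract names must be present.
    missing = [n for n in contract.feature_names if n not in shapes]
    if missing:
        raise FeatureContractError(f"missing features: {missing}")

    # 2. Recorded aux layer ids, when present, must match exactly.
    #    An absent record (None) skips the check.
    if recorded_aux_layer_ids is not None:
        if tuple(recorded_aux_layer_ids) != contract.aux_hidden_state_layer_ids:
            raise FeatureContractError(
                f"aux layers {tuple(recorded_aux_layer_ids)} != "
                f"{contract.aux_hidden_state_layer_ids}"
            )

    # 3. Width checks are keyed on feature names, so they self-skip
    #    for schemas that do not emit these features.
    aux = shapes.get("hidden_state")
    if aux is not None and aux[-1] != contract.expected_aux_width:
        raise FeatureContractError(
            f"aux width {aux[-1]} != {contract.expected_aux_width}"
        )
    target = shapes.get("target")
    if target is not None and target[-1] != contract.expected_target_dim:
        raise FeatureContractError(
            f"target dim {target[-1]} != {contract.expected_target_dim}"
        )
```

Because the check raises at the extraction boundary, a width mismatch surfaces with the offending feature name and sizes in the message rather than as a shape error inside the trainer.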

## Internal mechanics

@@ -18,10 +18,10 @@ flowchart TD
classDef data fill:#fdeede,stroke:#d6893b,color:#6b3a0b;

A[lease_prompt_tasks] --> B[generate_features per batch]
-B --> C[SGLangAdapter group by len single forward]
+B --> C[PolicyFeatureAdapter group by len single forward]
C --> D[TargetEngine.capture]
D --> E[_project_target logits or pruned_logits]
-E --> F[verify_capture vs CaptureConfig]
+E --> F[verify_feature_contract vs FeatureContract]
F -->|mismatch| G[fail_prompt_tasks non retryable]
F -->|ok| H[FeatureStore put]
H -->|put error| I[abort then fail retryable]
@@ -34,7 +34,7 @@ flowchart TD
class G control;
```

-The rollout plane turns leased `PromptTask`s into per-sample feature tensors and commits only their typed `SampleRef` metadata — it never hands a tensor to the controller. `RolloutWorker.run_once` is the core loop: lease up to `max_tasks`, call `feature_source.generate_features(tasks, capture=...)` once for the whole batch, enforce a strict `len(feats)==len(tasks)` contract, then per sample pop the out-of-band `__aux_layer_ids__`, run `verify_capture`, and on success `FeatureStore.put` (tensors go straight to the data plane). Every leased task ends in exactly one terminal controller action — `commit_samples` on success or `fail_prompt_tasks` on generate failure / wrong count / capture mismatch / put failure — with `sample_id = f"{run_id}:{task.task_id}"` deterministic and a put exception triggering `abort`. `CaptureConfig` is a frozen, strategy-derived contract carrying `feature_names`, `aux_hidden_state_layer_ids`, `target_repr`, and the derived `expected_aux_width` / `expected_target_dim()`. `verify_capture` is the loud pre-`put` validator: it checks name presence, aux-layer-id equality, aux last-dim width, and target last-dim, gating `pruned_logits` on a non-None `vocab_map_version`, raising `CaptureMismatchError` at the boundary. `SGLangAdapter` is the only place target to draft projection happens (`_project_target`: passthrough for `logits`, `t2d`-indexing for `pruned_logits`), batching equal-length tasks into one padding-free `TargetEngine.capture` forward and slicing rows back into original task order.
+The rollout plane turns leased `PromptTask`s into per-sample feature tensors and commits only their typed `SampleRef` metadata — it never hands a tensor to the controller. `RolloutWorker.run_once` is the core loop: lease up to `max_tasks`, call `feature_source.generate_features(tasks, capture=...)` once for the whole batch, enforce a strict `len(feats)==len(tasks)` contract, then per sample pop the out-of-band `__aux_layer_ids__`, run `verify_feature_contract`, and on success `FeatureStore.put` (tensors go straight to the data plane). Every leased task ends in exactly one terminal controller action — `commit_samples` on success or `fail_prompt_tasks` on generate failure / wrong count / contract mismatch / put failure — with `sample_id = f"{run_id}:{task.task_id}"` deterministic and a put exception triggering `abort`. `FeatureContract` is a frozen, strategy-derived contract carrying `feature_names`, `aux_hidden_state_layer_ids`, `target_repr`, and the derived `expected_aux_width` / `expected_target_dim()`. `verify_feature_contract` is the loud pre-`put` validator: it checks name presence, aux-layer-id equality, aux last-dim width, and target last-dim, gating `pruned_logits` on a non-None `vocab_map_version`, raising `FeatureContractError` at the boundary. `PolicyFeatureAdapter` is the ONE runtime adapter (per-strategy `FeatureSchema` decides the emitted dict; `SGLangAdapter` / `DFlashAdapter` are thin schema-pinning subclasses) and the only place target to draft projection happens (`_project_target`: passthrough for `logits`, `t2d`-indexing for `pruned_logits`), batching equal-length tasks into one padding-free `TargetEngine.capture` forward — which must return a typed `TargetCaptureBatch` — and slicing rows back into original task order.
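
The length-grouped batching described above can be sketched in plain Python. This is an illustration under stated assumptions, not the adapter's code: prompts are lists of token ids rather than tensors, and `capture_batch` is a hypothetical callable standing in for the `TargetEngine.capture` forward, returning one row per prompt in batch order.

```python
from typing import Callable, Dict, List, Optional, Sequence


def generate_in_length_groups(
    prompts: Sequence[List[int]],
    capture_batch: Callable[[List[List[int]]], List[dict]],
) -> List[dict]:
    """Run one padding-free forward per prompt length, keep task order."""
    out: List[Optional[dict]] = [None] * len(prompts)

    # Bucket original task indices by prompt length so every batch
    # holds equal-length rows and needs no padding.
    groups: Dict[int, List[int]] = {}
    for i, ids in enumerate(prompts):
        groups.setdefault(len(ids), []).append(i)

    for idxs in groups.values():
        rows = capture_batch([prompts[i] for i in idxs])
        # Mirror the strict count contract: one row per submitted prompt.
        if len(rows) != len(idxs):
            raise ValueError(f"expected {len(idxs)} rows, got {len(rows)}")
        # Scatter each row back to its original position.
        for row, original_index in zip(rows, idxs):
            out[original_index] = row

    return [row for row in out if row is not None]
```

Writing results into a preallocated list by original index is what lets the caller rely on `len(feats) == len(tasks)` and positional pairing of features with tasks, regardless of how the prompts were grouped.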

## Endpoints

@@ -44,8 +44,8 @@ The rollout plane turns leased `PromptTask`s into per-sample feature tensors and
|---|---|---|
| `RolloutWorker` | `DataFlowController.register_rollout_worker` | control |
| `RolloutWorker` | `DataFlowController.lease_prompt_tasks` | control |
-| `RolloutWorker` | `SGLangAdapter.generate_features` | compute |
-| `SGLangAdapter` | `TargetEngine.capture` (→ `SGLangCaptureBackend`) | compute |
+| `RolloutWorker` | `PolicyFeatureAdapter.generate_features` | compute |
+| `PolicyFeatureAdapter` | `TargetEngine.capture` (→ `SGLangCaptureBackend`) | compute |
| `RolloutWorker` | `FeatureStore.put` | data |
| `RolloutWorker` | `FeatureStore.abort` | data |
| `RolloutWorker` | `DataFlowController.commit_samples` | control |
2 changes: 1 addition & 1 deletion specforge/inference/__init__.py
@@ -1,5 +1,5 @@
# coding=utf-8
-"""Inference / rollout plane: rollout worker, capture config, adapters, target engines.
+"""Inference / rollout plane: rollout worker, feature contract, adapters, target engines.

Submodules import the SpecForge model / SGLang code, so they are imported
explicitly by callers rather than at package load.
9 changes: 6 additions & 3 deletions specforge/inference/adapters/__init__.py
@@ -1,6 +1,9 @@
# coding=utf-8
-"""FeatureSource adapters: per-strategy capture over a TargetEngine.
+"""FeatureSource adapters: schema-parameterized capture over a TargetEngine.

-``eagle3.SGLangAdapter`` (default) and ``dflash.DFlashAdapter`` implement the
-``rollout_worker.FeatureSource`` protocol.
+``policy.PolicyFeatureAdapter`` is the single runtime adapter (length grouping,
+batched capture, per-sample slicing, vocab projection); each strategy registers
+a ``FeatureSchema`` for the store-ready dict shape. ``eagle3.SGLangAdapter``
+and ``dflash.DFlashAdapter`` are thin schema-pinning subclasses kept for
+back-compat; all implement the ``rollout_worker.FeatureSource`` protocol.
"""
127 changes: 33 additions & 94 deletions specforge/inference/adapters/dflash.py
@@ -6,47 +6,41 @@
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
-"""DFlashAdapter: the DFlash counterpart of SGLangAdapter.
-
-Wraps a DFlash ``TargetEngine`` (sglang / hf), calling its generic ``capture(...)``
-(the legacy ``generate_dflash_data`` is kept as a back-compat alias), and returns
-per-sample feature dicts for the DataFlow rollout. DFlash's schema is
-``{input_ids, hidden_states, loss_mask}`` — note ``hidden_states`` is the
-concatenated target capture layers, and there is NO ``target`` distribution / no
-vocab projection (DFlash trains on hard real-token labels), so unlike
-``SGLangAdapter`` there is no ``_project_target`` / ``t2d`` step.
-
-``verify_capture`` (run by the RolloutWorker before any store write) keys its
-eagle3-specific aux/target checks on the feature names ``"hidden_state"`` /
-``"target"``, which DFlash does not emit, so those checks self-skip; the
-recorded-aux-layer check is skipped too because the RolloutWorker reads it via
-``feats.pop("__aux_layer_ids__", None)`` and DFlash simply omits the key.
-
-Imports SpecForge model code transitively (via the target backend), so it is
-imported by rollout entry points, not at package load.
+"""DFlashAdapter: the DFlash schema pinned onto PolicyFeatureAdapter.
+
+DFlash's store-ready dict is ``{input_ids, hidden_states, loss_mask}`` — note
+``hidden_states`` is the concatenated target capture layers, and there is NO
+``target`` distribution / no vocab projection (DFlash trains on hard real-token
+labels), so ``DFLASH_FEATURE_SCHEMA`` sets ``target_feature=None`` and skips the
+``t2d`` projection step entirely.
+
+``verify_feature_contract`` (run by the RolloutWorker before any store write)
+keys its eagle3-specific aux/target checks on the feature names
+``"hidden_state"`` / ``"target"``, which DFlash does not emit, so those checks
+self-skip; the recorded-aux-layer check is skipped too because the schema does
+not emit ``__aux_layer_ids__`` and the RolloutWorker reads it via
+``feats.pop("__aux_layer_ids__", None)``.
+
+Kept as a named class for back-compat; the shared runtime conversion lives in
+:class:`~specforge.inference.adapters.policy.PolicyFeatureAdapter`.
"""

from __future__ import annotations

-from typing import Any, Dict, List, Optional
+from typing import Optional

 import torch

-from specforge.inference.capture import CaptureConfig
-from specforge.runtime.contracts import PromptTask
-
-
-def _as_2d_long(values, device) -> torch.Tensor:
-    t = torch.as_tensor(values, dtype=torch.long, device=device)
-    if t.ndim == 1:
-        t = t.unsqueeze(0)
-    return t
+from specforge.inference.adapters.policy import (
+    DFLASH_FEATURE_SCHEMA,
+    PolicyFeatureAdapter,
+)


-class DFlashAdapter:
-    """Adapter over a SpecForge DFlash ``TargetEngine`` (via its generic ``capture()``)."""
+class DFlashAdapter(PolicyFeatureAdapter):
+    """DFlash ``FeatureSource`` over a ``TargetEngine`` (via its generic ``capture()``)."""

-    SUPPORTED_FEATURE_NAMES = {"input_ids", "hidden_states", "loss_mask"}
+    SUPPORTED_FEATURE_NAMES = DFLASH_FEATURE_SCHEMA.names

     def __init__(
         self,
@@ -55,69 +49,14 @@ def __init__(
         device: str = "cuda",
         t2d: Optional[torch.Tensor] = None, # unused (DFlash has no vocab map); kept
     ) -> None: # for a uniform make_adapter(target_model, *, device, t2d) signature
-        self.target_model = target_model
-        self.device = device
-        self._healthy = True
-
-    def generate_features(
-        self, tasks: List[PromptTask], *, capture: CaptureConfig
-    ) -> List[Dict[str, Any]]:
-        """Extract per-sample DFlash features, batching equal-length prompts.
-
-        Mirrors SGLangAdapter's length-grouped batching, but calls the engine's
-        generic ``capture(...)`` and emits the DFlash schema. The target must have
-        had ``set_capture_layers`` called so ``hidden_states`` width matches the
-        draft's ``len(target_layer_ids) * hidden_size``.
-        """
-        out: List[Optional[Dict[str, Any]]] = [None] * len(tasks)
-
-        groups: Dict[int, List[int]] = {}
-        for i, task in enumerate(tasks):
-            groups.setdefault(len(task.payload["input_ids"]), []).append(i)
-
-        for _length, idxs in groups.items():
-            input_ids = torch.stack(
-                [
-                    _as_2d_long(tasks[i].payload["input_ids"], self.device)[0]
-                    for i in idxs
-                ],
-                dim=0,
-            ) # (G, L)
-            length = input_ids.shape[1]
-            loss_mask = torch.stack(
-                [
-                    (
-                        _as_2d_long(tasks[i].payload["loss_mask"], self.device)[0]
-                        if "loss_mask" in tasks[i].payload
-                        else torch.ones(length, dtype=torch.long, device=self.device)
-                    )
-                    for i in idxs
-                ],
-                dim=0,
-            )
-            attention_mask = torch.ones_like(input_ids)
-            data = self.target_model.capture(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                loss_mask=loss_mask,
-            )
-            for j, gi in enumerate(idxs):
-                out[gi] = {
-                    "input_ids": data.input_ids[j : j + 1],
-                    "hidden_states": data.hidden_states[j : j + 1],
-                    "loss_mask": data.loss_mask[j : j + 1],
-                    # DFlash emits no eagle3 aux/target features. The recorded
-                    # aux-layer check in verify_capture is skipped for free: the
-                    # RolloutWorker reads it via feats.pop("__aux_layer_ids__", None),
-                    # so an absent key is identical to an explicit None.
-                }
-        return out
-
-    def health(self) -> Dict[str, Any]:
-        return {
-            "healthy": self._healthy,
-            "backend": getattr(self.target_model, "backend", "unknown"),
-        }
+        super().__init__(
+            target_model,
+            schema=DFLASH_FEATURE_SCHEMA,
+            device=device,
+            t2d=t2d,
+            # DFlash engines' capture() does not take shard_returns; never pass it.
+            shard_returns=None,
+        )


 __all__ = ["DFlashAdapter"]
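
For readers skimming this refactor, the schema-pinning pattern can be illustrated with a small self-contained sketch. The classes below are hypothetical stand-ins, not the PR's `FeatureSchema` or `PolicyFeatureAdapter`: only the `names` and `target_feature` attributes are taken from the diff, and treating `t2d` as a plain list of kept vocabulary indices is an assumption made for illustration.

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet, List, Optional


@dataclass(frozen=True)
class SketchSchema:
    # Hypothetical stand-in for FeatureSchema: which keys to emit, and
    # which one (if any) holds the target distribution to project.
    names: FrozenSet[str]
    target_feature: Optional[str] = None


class SketchAdapter:
    """One runtime adapter; the schema decides the emitted dict."""

    def __init__(self, schema: SketchSchema, t2d: Optional[List[int]] = None):
        self.schema = schema
        self.t2d = t2d

    def emit(self, captured: Dict[str, list]) -> Dict[str, list]:
        # Keep only the features the schema names.
        out = {name: captured[name] for name in self.schema.names}
        # Project the target onto the draft vocabulary only when the
        # schema declares a target feature; otherwise skip the step.
        target = self.schema.target_feature
        if target is not None and self.t2d is not None:
            out[target] = [out[target][i] for i in self.t2d]
        return out


DFLASH_SKETCH_SCHEMA = SketchSchema(
    names=frozenset({"input_ids", "hidden_states", "loss_mask"}),
    target_feature=None,  # no target distribution, so no projection
)


class DFlashSketchAdapter(SketchAdapter):
    """Thin subclass that only pins the schema."""

    def __init__(self, t2d: Optional[List[int]] = None):
        super().__init__(DFLASH_SKETCH_SCHEMA, t2d=t2d)
```

With this shape, a strategy-specific subclass carries no batching or slicing logic of its own, which is consistent with the 94 deleted lines in this file collapsing into a single `super().__init__` call.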