Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions specforge/runtime/inference/dflash_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# coding=utf-8
# Copyright 2024 The SpecForge team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""DFlashAdapter: the DFlash counterpart of SGLangAdapter.

Wraps a ``DFlashTargetModel`` (sglang / hf, both expose ``generate_dflash_data``)
and returns per-sample feature dicts for the DataFlow rollout. DFlash's schema is
``{input_ids, hidden_states, loss_mask}`` — note ``hidden_states`` is the
concatenated target capture layers, and there is NO ``target`` distribution / no
vocab projection (DFlash trains on hard real-token labels), so unlike
``SGLangAdapter`` there is no ``_project_target`` / ``t2d`` step.

``verify_capture`` (run by the RolloutWorker before any store write) keys its
eagle3-specific aux/target checks on the feature names ``"hidden_state"`` /
``"target"``, which DFlash does not emit, so those checks self-skip; the
recorded-aux-layer check is skipped too because the RolloutWorker reads it via
``feats.pop("__aux_layer_ids__", None)`` and DFlash simply omits the key.

Imports SpecForge model code transitively (via the target backend), so it is
imported by rollout entry points, not at package load.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional

import torch

from specforge.runtime.contracts import PromptTask
from specforge.runtime.inference.capture import CaptureConfig


def _as_2d_long(values, device) -> torch.Tensor:
t = torch.as_tensor(values, dtype=torch.long, device=device)
if t.ndim == 1:
t = t.unsqueeze(0)
return t


class DFlashAdapter:
    """Adapter over a SpecForge ``DFlashTargetModel`` (any ``generate_dflash_data``).

    Turns a list of prompt tasks into per-sample feature dicts with the keys
    ``input_ids``, ``hidden_states`` and ``loss_mask`` by batching prompts of
    equal length and calling ``target_model.generate_dflash_data`` once per
    batch.
    """

    # Feature names this adapter emits for every sample.
    SUPPORTED_FEATURE_NAMES = {"input_ids", "hidden_states", "loss_mask"}

    def __init__(
        self,
        target_model,
        *,
        device: str = "cuda",
        t2d: Optional[torch.Tensor] = None,  # unused (DFlash has no vocab map); kept
    ) -> None:  # for a uniform make_adapter(target_model, *, device, t2d) signature
        """Store the target model and the device prompts are moved to.

        Args:
            target_model: Object exposing ``generate_dflash_data``.
            device: Device on which the input tensors are created.
            t2d: Accepted for signature parity with other adapters; never read.
        """
        self.target_model = target_model
        self.device = device
        self._healthy = True

    def generate_features(
        self, tasks: List[PromptTask], *, capture: CaptureConfig
    ) -> List[Dict[str, Any]]:
        """Extract per-sample DFlash features, batching equal-length prompts.

        Mirrors SGLangAdapter's length-grouped batching, but calls
        ``generate_dflash_data`` and emits the DFlash schema. The target must have
        had ``set_capture_layers`` called so ``hidden_states`` width matches the
        draft's ``len(target_layer_ids) * hidden_size``.

        Args:
            tasks: Prompt tasks whose ``payload`` holds ``input_ids`` (1-D, or
                2-D with the prompt in row 0) and optionally ``loss_mask``.
                A missing ``loss_mask`` defaults to all ones.
            capture: Accepted for interface parity; not read by this method.

        Returns:
            One dict per task, in task order, each holding batch-size-1 slices
            of ``input_ids``, ``hidden_states`` and ``loss_mask`` taken from the
            target's output. An empty ``tasks`` list yields ``[]``.
        """
        out: List[Optional[Dict[str, Any]]] = [None] * len(tasks)

        # Normalise every prompt to its token row once. Grouping must key on the
        # row length, not ``len(payload["input_ids"])``: for a 2-D ``[1, L]``
        # payload ``len`` is the batch dim (1), which would lump prompts of
        # different lengths together and make ``torch.stack`` fail.
        rows = [
            _as_2d_long(task.payload["input_ids"], self.device)[0] for task in tasks
        ]

        groups: Dict[int, List[int]] = {}
        for i, row in enumerate(rows):
            groups.setdefault(row.shape[-1], []).append(i)

        for length, idxs in groups.items():
            input_ids = torch.stack([rows[i] for i in idxs], dim=0)  # (G, L)
            loss_mask = torch.stack(
                [
                    (
                        _as_2d_long(tasks[i].payload["loss_mask"], self.device)[0]
                        if "loss_mask" in tasks[i].payload
                        else torch.ones(length, dtype=torch.long, device=self.device)
                    )
                    for i in idxs
                ],
                dim=0,
            )
            # Every prompt in a group has the same length, so there is no padding
            # and the attention mask is all ones.
            attention_mask = torch.ones_like(input_ids)
            data = self.target_model.generate_dflash_data(
                input_ids=input_ids,
                attention_mask=attention_mask,
                loss_mask=loss_mask,
            )
            # Scatter the batched output back to the original task positions,
            # keeping a leading batch dim of 1 on every feature.
            for j, gi in enumerate(idxs):
                out[gi] = {
                    "input_ids": data.input_ids[j : j + 1],
                    "hidden_states": data.hidden_states[j : j + 1],
                    "loss_mask": data.loss_mask[j : j + 1],
                    # DFlash emits no eagle3 aux/target features. The recorded
                    # aux-layer check in verify_capture is skipped for free: the
                    # RolloutWorker reads it via feats.pop("__aux_layer_ids__", None),
                    # so an absent key is identical to an explicit None.
                }
        return out

    def health(self) -> Dict[str, Any]:
        """Return the health flag and the target's ``backend`` (or ``"unknown"``)."""
        return {
            "healthy": self._healthy,
            "backend": getattr(self.target_model, "backend", "unknown"),
        }


__all__ = ["DFlashAdapter"]
112 changes: 112 additions & 0 deletions specforge/runtime/training/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,118 @@ def _eagle3_offline_collate():
)


# --- DFlash -----------------------------------------------------------------
# DFlash uses its own feature schema ('hidden_states' = the concatenated target
# capture layers, NO eagle3 aux/target swap, NO target distribution / vocab map).
# Offline: the reader reuses OfflineManifestReader with dflash feature_keys; the
# transform slices to max_len (no swap); the collate pads + emits {input_ids,
# hidden_states, loss_mask}. Online: a DFlashAdapter wraps generate_dflash_data
# and emits the same schema. The DFlashTrainStrategy already drops into the
# unchanged TrainerCore/Backend/Loader.

from specforge.runtime.training.strategy import DFlashTrainStrategy


def _dflash_offline_reader(hidden_states_path, *, run_id, ttt_length, max_len):
    """Build an ``OfflineManifestReader`` configured for the DFlash schema.

    The reader itself is schema-agnostic; only its EAGLE3 defaults
    (``feature_keys`` / ``strategy`` / ``target_repr``) are overridden here.
    DFlash has no target distribution, so ``target_repr`` is ``None`` (it is
    only stored in ref metadata).
    """
    from specforge.runtime.data_plane.offline_reader import OfflineManifestReader

    dflash_overrides = {
        "strategy": "dflash",
        "feature_keys": ("input_ids", "loss_mask", "hidden_states"),
        "target_repr": None,
    }
    return OfflineManifestReader(
        hidden_states_path,
        run_id=run_id,
        ttt_length=ttt_length,
        max_len=max_len,
        **dflash_overrides,
    )


def _dflash_process_data(raw, max_len):
"""Offline DFlash per-sample normalization.

Slices to ``max_len`` and restores a leading batch dim. NO eagle3-style
aux->input / hidden->target swap and NO last-position loss zeroing (the online
DFlash path does neither), so offline matches online training.
"""
input_ids = raw["input_ids"][:max_len]
loss_mask = raw["loss_mask"][:max_len]
hidden_states = raw["hidden_states"]
if hidden_states.dim() == 3: # stored as [1, seq, W]; drop the saved batch dim
hidden_states = hidden_states.squeeze(0)
hidden_states = hidden_states[:max_len]
return {
"input_ids": input_ids[None, :],
"loss_mask": loss_mask[None, :],
"hidden_states": hidden_states[None, :, :],
}


def _dflash_offline_transform(max_len):
return lambda raw: _dflash_process_data(raw, max_len)


def _dflash_offline_collate():
"""Right-pad ragged samples to the batch max length and concat along batch.

DataCollatorWithPadding is eagle3-specific (hardwires attention_mask + the
hidden_state/target keys), so DFlash uses this minimal collate. loss_mask is
zero-padded, so padded positions contribute no loss. (Single-rank; no SP
sharding multiple — sequence-parallel offline DFlash is a follow-up.)
"""

def collate(feats):
maxlen = max(f["input_ids"].shape[-1] for f in feats)

def pad2d(t): # [1, n] -> [1, maxlen]
n = t.shape[-1]
if n == maxlen:
return t
return torch.cat([t, t.new_zeros(t.shape[0], maxlen - n)], dim=-1)

def pad3d(t): # [1, n, W] -> [1, maxlen, W]
n = t.shape[1]
if n == maxlen:
return t
return torch.cat(
[t, t.new_zeros(t.shape[0], maxlen - n, t.shape[2])], dim=1
)

return {
"input_ids": torch.cat([pad2d(f["input_ids"]) for f in feats], dim=0),
"loss_mask": torch.cat([pad2d(f["loss_mask"]) for f in feats], dim=0),
"hidden_states": torch.cat(
[pad3d(f["hidden_states"]) for f in feats], dim=0
),
}

return collate


def _dflash_adapter(target_model, *, device="cuda", t2d=None):
    """Construct the online ``DFlashAdapter`` around ``target_model``.

    The adapter module is imported inside the function rather than at module
    import time. ``device`` and ``t2d`` are forwarded unchanged.
    """
    from specforge.runtime.inference.dflash_adapter import DFlashAdapter

    adapter = DFlashAdapter(target_model, device=device, t2d=t2d)
    return adapter


# Register the "dflash" strategy: bundle the offline reader / transform / collate
# factories and the online adapter factory defined above into one StrategySpec.
register_strategy(
    StrategySpec(
        name="dflash",
        required_features=frozenset(DFlashTrainStrategy.required_features),
        # ``target_head`` is accepted so the factory signature matches what the
        # caller passes, but it is ignored: the strategy is built from
        # ``wrapped`` alone (consistent with uses_target_head=False below).
        make_strategy=lambda wrapped, *, target_head=None: DFlashTrainStrategy(wrapped),
        uses_target_head=False,
        make_offline_reader=_dflash_offline_reader,
        make_offline_transform=_dflash_offline_transform,
        make_offline_collate=_dflash_offline_collate,
        # Online samples use the shared concat_collate rather than the padding
        # collate used offline.
        make_online_collate=lambda: concat_collate,
        make_adapter=_dflash_adapter,
        supports_online=True,
    )
)


__all__ = [
"StrategySpec",
"concat_collate",
Expand Down
100 changes: 100 additions & 0 deletions tests/test_runtime/_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,103 @@ def build_eagle3(workdir, ttt=3):
draft_model=draft_model, length=ttt, attention_backend="flex_attention"
).cuda()
return eagle3_model, target_head


# --- DFlash fixtures ---------------------------------------------------------
# DFlash has its OWN feature schema: 'hidden_states' = concat of the target
# capture layers (NO eagle3 aux/target swap, NO target distribution). For a
# single draft layer the capture set is one layer, so width == hidden_size.


def write_offline_files_dflash(d, n=4, seq=32, hidden=H, vocab=V, seed=0):
    """Write ``n`` synthetic DFlash offline ``.ckpt`` files into ``d`` and return ``d``.

    Each file holds ``{input_ids, loss_mask, hidden_states}``. No production
    dumper for DFlash exists yet (prepare_hidden_states.py is the EAGLE3
    dumper), so the offline DataFlow path is exercised with synthetic files.
    ``loss_mask`` is all-ones with seq >= 2*block_size so anchor sampling
    succeeds; ``hidden_states`` is bf16 (uniform dtype across files for the
    loader's spec check).
    """
    os.makedirs(d, exist_ok=True)
    # One seeded generator is shared by all files, so the output is
    # reproducible for a given ``seed``.
    g = torch.Generator().manual_seed(seed)
    for i in range(n):
        token_ids = torch.randint(0, vocab, (seq,), generator=g)
        mask = torch.ones(seq, dtype=torch.long)
        # width == len(target_layer_ids)*hidden; single draft layer -> hidden
        states = torch.randn(1, seq, hidden, generator=g).to(torch.bfloat16)
        sample = {
            "input_ids": token_ids,
            "loss_mask": mask,
            "hidden_states": states,
        }
        torch.save(sample, os.path.join(d, f"{i:04d}.ckpt"))
    return d


def build_dflash(
    workdir,
    *,
    hidden=H,
    vocab=V,
    target_layers=4,
    draft_layers=1,
    block_size=4,
    num_anchors=8,
    mask_token_id=0,
    attention_backend="sdpa",
):
    """Build a tiny OnlineDFlashModel on cuda, mirroring scripts/train_dflash.build_models.

    A small Qwen3 target is created and saved under ``workdir``; the draft
    config is then derived from that saved target, and the draft model plus the
    target's embeddings / LM head are assembled into an ``OnlineDFlashModel``.

    Args:
        workdir: Directory under which ``dflash_target`` is written.
        hidden: Hidden size of the target (intermediate size is ``2 * hidden``).
        vocab: Vocabulary size of the target.
        target_layers: Number of hidden layers in the target.
        draft_layers: Number of hidden layers in the draft.
        block_size: Block size written to the draft config.
        num_anchors: Forwarded to ``OnlineDFlashModel``.
        mask_token_id: Mask token id set on the draft config, the draft model
            and the ``OnlineDFlashModel``.
        attention_backend: Attention implementation set on the draft config and
            forwarded to ``OnlineDFlashModel``.

    Returns (dflash_model, hidden_states_width, target_dir, target_layer_ids).
    target_dir holds the saved tiny Qwen3 target (load it as an HF DFlash target for
    the ONLINE path); target_layer_ids are the capture layers (== set_capture_layers).
    For draft_layers=1 the capture set is one target layer so width == hidden.
    """
    # Imported lazily so these dependencies are only loaded when the fixture is
    # actually used.
    from transformers import AutoConfig, Qwen3Config, Qwen3ForCausalLM

    from specforge.core.dflash import OnlineDFlashModel
    from specforge.modeling.draft.dflash import DFlashDraftModel
    from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead

    # tiny Qwen3 target saved to disk (draft config is derived from it, as in train_dflash)
    tcfg = Qwen3Config(
        hidden_size=hidden,
        intermediate_size=2 * hidden,
        num_hidden_layers=target_layers,
        num_attention_heads=4,
        num_key_value_heads=2,
        vocab_size=vocab,
        max_position_embeddings=512,
        rms_norm_eps=1e-5,
        tie_word_embeddings=False,
    )
    # Seed before instantiating the target; keep this ahead of the model
    # construction below.
    torch.manual_seed(1234)
    target_dir = os.path.join(workdir, "dflash_target")
    Qwen3ForCausalLM(tcfg).save_pretrained(target_dir)

    # Start from the saved target's config and override only the draft-specific
    # fields; these must be set before DFlashDraftModel is constructed from it.
    draft_config = AutoConfig.from_pretrained(target_dir)
    draft_config.num_hidden_layers = draft_layers
    draft_config.block_size = block_size
    draft_config.num_target_layers = target_layers
    draft_config.dflash_config = {"mask_token_id": mask_token_id}
    draft_config._attn_implementation = attention_backend

    draft_model = DFlashDraftModel(draft_config).to(device="cuda", dtype=torch.bfloat16)
    draft_model.mask_token_id = mask_token_id

    # Load the embeddings and LM head from the saved target checkpoint.
    target_components = TargetEmbeddingsAndHead.from_pretrained(
        target_dir, lm_head_key="lm_head.weight", device="cuda", dtype=torch.bfloat16
    )

    dflash_model = OnlineDFlashModel(
        draft_model=draft_model,
        target_lm_head=target_components.lm_head,
        target_embed_tokens=target_components.embed_tokens,
        block_size=draft_model.block_size,
        mask_token_id=mask_token_id,
        attention_backend=attention_backend,
        num_anchors=num_anchors,
        loss_type="dflash",
    ).cuda()
    # hidden_states width = number of captured target layers * hidden size.
    width = len(draft_model.target_layer_ids) * hidden
    return dflash_model, width, target_dir, list(draft_model.target_layer_ids)
Loading
Loading