Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions scripts/train_domino.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,13 +429,11 @@ def get_lambda_base(
lambda_start: float = 1.0,
decay_ratio: float = 0.5,
) -> float:
decay_steps = max(1, int(total_steps * decay_ratio))
progress = min(global_step / decay_steps, 1.0)
lambda_base = lambda_start * (1.0 - progress)
# Delegates to the runtime's single source of the Domino lambda schedule so the
# standalone script and DominoTrainStrategy cannot drift.
from specforge.runtime.training.strategy import linear_lambda_base

# Clamp to [0, 1].
lambda_base = max(0.0, min(1.0, lambda_base))
return lambda_base
return linear_lambda_base(global_step, total_steps, lambda_start, decay_ratio)


def main():
Expand Down
2 changes: 1 addition & 1 deletion specforge/runtime/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

RunMode = Literal["online", "offline"]
DeploymentMode = Literal["local_colocated", "dataflow_colocated", "disaggregated"]
DraftStrategyName = Literal["eagle3", "dflash"]
DraftStrategyName = Literal["eagle3", "dflash", "domino"]
# Tagged union for the EAGLE3 target feature. The *strategy* owns the
# projection so the trainer core stays branch-free:
# - pruned_logits: rollout applied the t2d vocab map; stored (seq, draft_vocab)
Expand Down
12 changes: 12 additions & 0 deletions specforge/runtime/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def _assemble_trainer(
accumulation_steps: int,
num_epochs: int,
max_steps: Optional[int],
total_steps: Optional[int] = None,
save_interval: int,
eval_interval: int,
tp_size: int,
Expand Down Expand Up @@ -124,6 +125,7 @@ def _assemble_trainer(
output_dir=output_dir,
num_epochs=num_epochs,
max_steps=max_steps,
total_steps=total_steps,
save_interval=save_interval,
eval_interval=eval_interval,
log_interval=log_interval,
Expand Down Expand Up @@ -279,6 +281,7 @@ def build_offline_runtime(
accumulation_steps: int = 1,
num_epochs: int = 1,
max_steps: Optional[int] = None,
total_steps: Optional[int] = None,
save_interval: int = 0,
eval_interval: int = 0,
tp_size: int = 1,
Expand Down Expand Up @@ -320,6 +323,7 @@ def build_offline_runtime(
accumulation_steps=accumulation_steps,
num_epochs=num_epochs,
max_steps=max_steps,
total_steps=total_steps,
save_interval=save_interval,
eval_interval=eval_interval,
tp_size=tp_size,
Expand Down Expand Up @@ -347,6 +351,7 @@ def build_disagg_offline_runtime(
accumulation_steps: int = 1,
num_epochs: int = 1,
max_steps: Optional[int] = None,
total_steps: Optional[int] = None,
save_interval: int = 0,
eval_interval: int = 0,
tp_size: int = 1,
Expand Down Expand Up @@ -380,6 +385,7 @@ def build_disagg_offline_runtime(
accumulation_steps=accumulation_steps,
num_epochs=num_epochs,
max_steps=max_steps,
total_steps=total_steps,
save_interval=save_interval,
eval_interval=eval_interval,
tp_size=tp_size,
Expand Down Expand Up @@ -420,6 +426,7 @@ def build_online_runtime(
accumulation_steps: int = 1,
num_epochs: int = 1,
max_steps: Optional[int] = None,
total_steps: Optional[int] = None,
save_interval: int = 0,
eval_interval: int = 0,
tp_size: int = 1,
Expand Down Expand Up @@ -480,6 +487,7 @@ def build_online_runtime(
accumulation_steps=accumulation_steps,
num_epochs=num_epochs,
max_steps=max_steps,
total_steps=total_steps,
save_interval=save_interval,
eval_interval=eval_interval,
tp_size=tp_size,
Expand Down Expand Up @@ -617,6 +625,7 @@ def build_disagg_online_consumer(
accumulation_steps: int = 1,
num_epochs: int = 1,
max_steps: Optional[int] = None,
total_steps: Optional[int] = None,
save_interval: int = 0,
eval_interval: int = 0,
tp_size: int = 1,
Expand Down Expand Up @@ -678,6 +687,7 @@ def build_disagg_online_consumer(
accumulation_steps=accumulation_steps,
num_epochs=num_epochs,
max_steps=max_steps,
total_steps=total_steps,
save_interval=save_interval,
eval_interval=eval_interval,
tp_size=tp_size,
Expand Down Expand Up @@ -788,6 +798,7 @@ def build_disagg_online_runtime(
accumulation_steps: int = 1,
num_epochs: int = 1,
max_steps: Optional[int] = None,
total_steps: Optional[int] = None,
save_interval: int = 0,
eval_interval: int = 0,
tp_size: int = 1,
Expand Down Expand Up @@ -867,6 +878,7 @@ def build_disagg_online_runtime(
accumulation_steps=accumulation_steps,
num_epochs=num_epochs,
max_steps=max_steps,
total_steps=total_steps,
save_interval=save_interval,
eval_interval=eval_interval,
tp_size=tp_size,
Expand Down
41 changes: 41 additions & 0 deletions specforge/runtime/training/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,47 @@ def _dflash_adapter(target_model, *, device="cuda", t2d=None):
)


# --- Domino -----------------------------------------------------------------
# Domino reuses DFlash's draft model (projector_type="domino" head), feature
# schema, offline transform/collate, and capture adapter (same
# generate_dflash_data -> hidden_states). The ONE difference is the loss: it
# blends a base loss with a step-decayed weight, so DominoTrainStrategy reads the
# StepContext (forward_loss(batch, ctx)). That is the whole reason a new algorithm
# needs anything beyond a spec entry here.

from specforge.runtime.training.strategy import DominoTrainStrategy


def _domino_offline_reader(hidden_states_path, *, run_id, ttt_length, max_len):
from specforge.runtime.data_plane.offline_reader import OfflineManifestReader

return OfflineManifestReader(
hidden_states_path,
run_id=run_id,
ttt_length=ttt_length,
max_len=max_len,
strategy="domino",
feature_keys=("input_ids", "loss_mask", "hidden_states"),
target_repr=None,
)


register_strategy(
StrategySpec(
name="domino",
required_features=frozenset(DominoTrainStrategy.required_features),
make_strategy=lambda wrapped, *, target_head=None: DominoTrainStrategy(wrapped),
uses_target_head=False,
make_offline_reader=_domino_offline_reader,
make_offline_transform=_dflash_offline_transform, # same schema as DFlash
make_offline_collate=_dflash_offline_collate,
make_online_collate=lambda: concat_collate,
make_adapter=_dflash_adapter, # same generate_dflash_data capture
supports_online=True,
)
)


__all__ = [
"StrategySpec",
"concat_collate",
Expand Down
120 changes: 117 additions & 3 deletions specforge/runtime/training/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,39 @@ class StepOutput:
metrics: Dict[str, Any]


@dataclass(frozen=True)
class StepContext:
"""Training-schedule state passed alongside the batch into ``forward_loss``.

Most strategies ignore it (their loss is a pure function of the batch). It
exists for objectives whose loss depends on *where in training* we are — e.g.
Domino blends a base loss with a weight ``lambda_base`` that decays over
``total_steps``. Threading this explicitly keeps schedule-state out of ad-hoc
kwargs and off the model's forward signature at the runtime seam.
"""

global_step: int = 0
total_steps: Optional[int] = None


def linear_lambda_base(
global_step: int,
total_steps: int,
lambda_start: float = 1.0,
decay_ratio: float = 0.5,
) -> float:
"""Domino base-loss weight: linear decay from ``lambda_start`` to 0 over the
first ``total_steps * decay_ratio`` steps, then 0, clamped to ``[0, 1]``.

Single source of the schedule for both the runtime ``DominoTrainStrategy`` and
``scripts/train_domino.py`` so the two cannot drift. Requires a real
``total_steps`` (> 0); callers with no schedule horizon decide the fallback.
"""
decay_steps = max(1, int(total_steps * decay_ratio))
progress = min(global_step / decay_steps, 1.0)
return max(0.0, min(1.0, lambda_start * (1.0 - progress)))


class DraftTrainStrategy(abc.ABC):
name: str
required_features: set
Expand All @@ -59,7 +92,9 @@ def validate_batch(self, batch: TrainBatch) -> None:
)

@abc.abstractmethod
def forward_loss(self, batch: TrainBatch) -> StepOutput: ...
def forward_loss(
self, batch: TrainBatch, ctx: Optional["StepContext"] = None
) -> StepOutput: ...

def checkpoint_state_filter(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Select the keys this strategy persists as draft weights."""
Expand Down Expand Up @@ -128,7 +163,9 @@ def _prepare_target(
# applied any shift; use the tensors as delivered.
return input_ids.to(device), target.to(device), loss_mask.to(device)

def forward_loss(self, batch: TrainBatch) -> StepOutput:
def forward_loss(
self, batch: TrainBatch, ctx: Optional[StepContext] = None
) -> StepOutput:
self.validate_batch(batch)
t = batch.tensors
device = self._device()
Expand Down Expand Up @@ -202,7 +239,9 @@ def trainable_module(self) -> nn.Module:
def _device(self) -> torch.device:
return next(self.dflash_model.parameters()).device

def forward_loss(self, batch: TrainBatch) -> StepOutput:
def forward_loss(
self, batch: TrainBatch, ctx: Optional[StepContext] = None
) -> StepOutput:
self.validate_batch(batch)
t = batch.tensors
device = self._device()
Expand All @@ -223,9 +262,84 @@ def checkpoint_state_filter(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
}


class DominoTrainStrategy(DraftTrainStrategy):
"""Domino block-parallel strategy wrapping ``OnlineDominoModel``.

Shares the trainer/backend/loader/checkpoint spine with DFlash and uses the
same feature schema ({input_ids, hidden_states, loss_mask}) and capture
adapter. The one thing Domino needs that the others don't: its loss blends a
base loss with a weight ``lambda_base`` that DECAYS over training, so it reads
the :class:`StepContext` to compute the schedule (every other strategy ignores
ctx). ``OnlineDominoModel.forward`` returns a (loss, accuracy, metrics) triple.
"""

name = "domino"
required_features = {"input_ids", "hidden_states", "loss_mask"}

def __init__(
self,
domino_model: nn.Module,
*,
lambda_start: float = 1.0,
decay_ratio: float = 0.5,
) -> None:
self.domino_model = domino_model
self.lambda_start = lambda_start
self.decay_ratio = decay_ratio

def trainable_module(self) -> nn.Module:
return self.domino_model

def _device(self) -> torch.device:
return next(self.domino_model.parameters()).device

def _lambda_base(self, ctx: Optional[StepContext]) -> float:
# Without schedule info (no ctx, or no known total_steps because neither
# total_steps nor max_steps was set on the controller) fall back to the pure
# final loss (lambda_base = 0). Otherwise use the shared linear schedule.
if ctx is None or not ctx.total_steps:
return 0.0
return linear_lambda_base(
ctx.global_step, ctx.total_steps, self.lambda_start, self.decay_ratio
)

def forward_loss(
self, batch: TrainBatch, ctx: Optional[StepContext] = None
) -> StepOutput:
self.validate_batch(batch)
t = batch.tensors
device = self._device()
lambda_base = self._lambda_base(ctx)
loss, accuracy, _metrics = self.domino_model(
input_ids=t["input_ids"].to(device),
hidden_states=t["hidden_states"].to(device),
loss_mask=t["loss_mask"].to(device),
lambda_base=lambda_base,
)
return StepOutput(
loss=loss,
metrics={
"accuracy": accuracy.detach(),
"lambda_base": torch.tensor(float(lambda_base)),
},
)

def checkpoint_state_filter(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
# Domino (like DFlash) keeps everything under draft_model.; the target
# embedding/head live in a separate module not persisted as draft weights.
return {
k.replace("draft_model.", ""): v
for k, v in state_dict.items()
if "draft_model." in k
}


__all__ = [
"DraftTrainStrategy",
"Eagle3TrainStrategy",
"DFlashTrainStrategy",
"DominoTrainStrategy",
"StepOutput",
"StepContext",
"linear_lambda_base",
]
Loading
Loading