Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 120 additions & 35 deletions scripts/train_eagle3_dataflow.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
"""Thin launcher: offline EAGLE3 training through the SpecForge DataFlow runtime.
"""Thin launcher: EAGLE3 training through the SpecForge DataFlow runtime.

This script is a *launcher* (M3): it builds models + optimizer, hands them to the
runtime, and runs ``TrainerController.fit``. No training logic lives here — the
loop, loss, projection, checkpoint, and eval all live in ``specforge.runtime``.

Reuses the existing model/data builders from ``scripts.train_eagle3`` so model
construction stays DRY; only the *orchestration* moves behind the runtime.
construction stays DRY; only the *orchestration* moves behind the runtime. Both
modes converge at ``SampleRef`` and share one trainer/strategy/FSDP path:

Example (offline):
* **offline** (``--train-hidden-states-path`` set): an ``OfflineManifestReader``
turns precomputed ``.ckpt`` files into refs.
* **online** (no hidden-states path): a ``RolloutWorker`` generates features from
the target model and commits refs onto the control plane's queue.

Examples:
# offline
torchrun --standalone --nproc_per_node 1 scripts/train_eagle3_dataflow.py \
--target-model-path <hf-model> --draft-model-config configs/llama3-8B-eagle3.json \
--target-model-path <hf-model> --draft-model-config <cfg.json> \
--train-data-path <prompts.jsonl> --train-hidden-states-path <features_dir> \
--output-dir ./output --max-num-steps 20

# online (no --train-hidden-states-path)
torchrun --standalone --nproc_per_node 1 scripts/train_eagle3_dataflow.py \
--target-model-path <hf-model> --draft-model-config <cfg.json> \
--train-data-path <prompts.jsonl> --output-dir ./output --max-num-steps 20
"""

from accelerate.utils import set_seed
Expand All @@ -26,13 +38,46 @@

from specforge.distributed import destroy_distributed, init_distributed
from specforge.optimizer import BF16Optimizer
from specforge.runtime.launch import build_offline_eagle3_runtime


def _target_hidden_and_vocab(target_model):
"""Best-effort (hidden_size, vocab_size) from an Eagle3 target backend."""
cfg = getattr(getattr(target_model, "model", None), "config", None)
if cfg is not None:
return int(cfg.hidden_size), int(cfg.vocab_size)
raise RuntimeError(
"could not read hidden_size/vocab_size from the target model; pass them explicitly"
)


def _extract_prompts(train_dataloader):
"""Flatten the online train dataloader into metadata-only PromptTask payloads.

Each prompt carries only ``input_ids`` + ``loss_mask`` (the control plane
never holds tensors); ``attention_mask`` recovers the true unpadded length.
"""
prompts = []
for batch in train_dataloader:
input_ids = batch["input_ids"]
loss_mask = batch["loss_mask"]
attn = batch.get("attention_mask")
for i in range(input_ids.shape[0]):
n = int(attn[i].sum().item()) if attn is not None else input_ids.shape[1]
prompts.append(
{
"payload": {
"input_ids": input_ids[i, :n].tolist(),
"loss_mask": loss_mask[i, :n].tolist(),
}
}
)
return prompts


def main():
parser, args = parse_args()
# parse_args() does not derive target_batch_size (train_eagle3.main computes
# it inline before building dataloaders); the offline runtime builder and
# it inline before building dataloaders); the runtime builder and
# build_dataloaders both read it, so derive it here too.
args.target_batch_size = args.tp_size * args.batch_size

Expand All @@ -55,17 +100,12 @@ def main():
sp_ring_size=args.sp_ring_size,
sp_ulysses_size=args.sp_ulysses_size,
)
if args.train_hidden_states_path is None:
raise SystemExit(
"train_eagle3_dataflow currently wires the OFFLINE path; pass "
"--train-hidden-states-path. (Online wiring composes RolloutWorker + "
"SGLangAdapter over the same control/data plane.)"
)

online = args.train_hidden_states_path is None

draft_config, draft_model, _ckpt, _resume = build_draft_model(args)
target_head, _ = build_target_model(args, draft_config, is_online=False)
# vocab mapping is produced from the prompt dataset exactly as today
_train, vocab_mapping_path, _eval = build_dataloaders(args, draft_config)
train_dataloader, vocab_mapping_path, _eval = build_dataloaders(args, draft_config)
draft_model.load_vocab_mapping(vocab_mapping_path)

from specforge import OnlineEagle3Model
Expand All @@ -79,7 +119,7 @@ def main():
kl_decay=args.kl_decay,
).cuda()

# built AFTER FSDP-wrap (inside the runtime) over the wrapped inner draft
# optimizer is built AFTER FSDP-wrap (inside the runtime) over the inner draft
def optimizer_factory(draft_module):
return BF16Optimizer(
draft_module,
Expand All @@ -89,26 +129,71 @@ def optimizer_factory(draft_module):
total_steps=args.total_steps or 10_000,
)

trainer, loader = build_offline_eagle3_runtime(
hidden_states_path=args.train_hidden_states_path,
eagle3_model=eagle3_model,
target_head=target_head,
optimizer_factory=optimizer_factory,
run_id="eagle3-offline",
output_dir=args.output_dir,
ttt_length=args.ttt_length,
max_len=args.max_length,
batch_size=args.target_batch_size,
accumulation_steps=args.draft_accumulation_steps,
num_epochs=args.num_epochs,
max_steps=args.max_num_steps,
save_interval=args.save_interval,
tp_size=args.tp_size,
sp_ulysses_size=args.sp_ulysses_size,
sp_ring_size=args.sp_ring_size,
logger=lambda m, s: print(f"step {s}: {m}", flush=True),
)
trainer.fit(loader)
logger = lambda m, s: print(f"step {s}: {m}", flush=True)

if online:
from specforge.runtime.launch import build_online_eagle3_runtime

# Online target produces features in-loop (any backend exposing
# generate_eagle3_data — HF or SGLang). is_online=True returns the model.
target_model, _ = build_target_model(args, draft_config, is_online=True)
hidden_size, vocab_size = _target_hidden_and_vocab(target_model)
prompts = _extract_prompts(train_dataloader)
print(f"[online] ingesting {len(prompts)} prompts for rollout", flush=True)

# num_epochs=1: the rollout output is a consume-once stream. Multi-epoch
# online (re-rollout each epoch) is a follow-up; one rollout pass here.
trainer, loader, workers, controller, drive_rollout = (
build_online_eagle3_runtime(
target_model=target_model,
prompts=prompts,
eagle3_model=eagle3_model,
optimizer_factory=optimizer_factory,
run_id="eagle3-online",
output_dir=args.output_dir,
target_hidden_size=hidden_size,
target_vocab_size=vocab_size,
target_repr="logits",
ttt_length=args.ttt_length,
batch_size=args.target_batch_size,
accumulation_steps=args.draft_accumulation_steps,
num_epochs=1,
max_steps=args.max_num_steps,
save_interval=args.save_interval,
tp_size=args.tp_size,
sp_ulysses_size=args.sp_ulysses_size,
sp_ring_size=args.sp_ring_size,
logger=logger,
)
)
produced = drive_rollout()
print(f"[online] rollout produced {produced} samples", flush=True)
trainer.fit(loader)
else:
from specforge.runtime.launch import build_offline_eagle3_runtime

target_head, _ = build_target_model(args, draft_config, is_online=False)
trainer, loader = build_offline_eagle3_runtime(
hidden_states_path=args.train_hidden_states_path,
eagle3_model=eagle3_model,
target_head=target_head,
optimizer_factory=optimizer_factory,
run_id="eagle3-offline",
output_dir=args.output_dir,
ttt_length=args.ttt_length,
max_len=args.max_length,
batch_size=args.target_batch_size,
accumulation_steps=args.draft_accumulation_steps,
num_epochs=args.num_epochs,
max_steps=args.max_num_steps,
save_interval=args.save_interval,
tp_size=args.tp_size,
sp_ulysses_size=args.sp_ulysses_size,
sp_ring_size=args.sp_ring_size,
logger=logger,
)
trainer.fit(loader)

destroy_distributed()


Expand Down
159 changes: 158 additions & 1 deletion specforge/runtime/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,165 @@ def build_offline_eagle3_runtime(
return trainer, loader


def build_online_eagle3_runtime(
    *,
    target_model,
    prompts,
    eagle3_model,
    optimizer_factory,
    run_id: str,
    output_dir: str,
    target_hidden_size: int,
    target_vocab_size: Optional[int] = None,
    draft_vocab_size: Optional[int] = None,
    target_repr: str = "logits",
    aux_hidden_state_layer_ids=None,
    vocab_map_version: Optional[str] = None,
    t2d=None,
    num_rollout_workers: int = 1,
    device: str = "cuda",
    ttt_length: int = 7,
    batch_size: int = 1,
    accumulation_steps: int = 1,
    num_epochs: int = 1,
    max_steps: Optional[int] = None,
    save_interval: int = 0,
    eval_interval: int = 0,
    tp_size: int = 1,
    sp_ulysses_size: int = 1,
    sp_ring_size: int = 1,
    collate_fn=None,
    logger=None,
):
    """Wire up the online-EAGLE3 dataflow.

    Returns ``(trainer, loader, workers, controller, drive_rollout)``.

    This is the online counterpart of :func:`build_offline_eagle3_runtime`.
    What changes is who produces the ``SampleRef``s: ``RolloutWorker``s are
    built around a ``SGLangAdapter`` wrapping ``target_model`` and a
    ``CaptureConfig`` derived from ``Eagle3TrainStrategy.required_features``,
    and they share the controller and the ``LocalFeatureStore`` created here.
    The loader reads from ``controller.sample_queue``; the loader, strategy
    and trainer are assembled the same way regardless of the producer.

    ``prompts`` is handed to ``controller.ingest_prompts`` as-is (e.g.
    ``[{"payload": {"input_ids": [...], "loss_mask": [...]}}]``).

    ``drive_rollout(max_rounds)`` starts every worker, then repeatedly calls
    ``run_once`` on each until a full round yields nothing (or ``max_rounds``
    is reached) and returns the number of refs produced. The launcher is
    expected to call it before ``trainer.fit(loader)``.

    The strategy is constructed with ``target_head=None``: the online path
    does not pass an lm-head to the strategy.

    When ``aux_hidden_state_layer_ids`` is ``None`` it falls back to
    ``target_model.aux_hidden_states_layers`` (empty tuple if absent/falsy).
    When ``collate_fn`` is falsy, per-sample features are concatenated along
    dim 0 for every key of the first sample.
    """
    import torch

    from specforge.runtime.inference.capture import CaptureConfig
    from specforge.runtime.inference.rollout_worker import RolloutWorker
    from specforge.runtime.inference.sglang_adapter import SGLangAdapter

    # --- control plane + feature store -------------------------------------
    controller = DataFlowController(run_id)
    controller.ingest_prompts(prompts)
    # Plain construction, same as the offline launcher (no residency cap).
    store = LocalFeatureStore(run_id)

    # --- rollout side: adapter, capture config, workers ---------------------
    layer_ids = aux_hidden_state_layer_ids
    if layer_ids is None:
        discovered = getattr(target_model, "aux_hidden_states_layers", ())
        layer_ids = tuple(discovered or ())

    adapter = SGLangAdapter(target_model, device=device, t2d=t2d)
    capture = CaptureConfig.from_strategy(
        required_features=Eagle3TrainStrategy.required_features,
        aux_hidden_state_layer_ids=tuple(layer_ids),
        target_repr=target_repr,
        target_hidden_size=target_hidden_size,
        target_vocab_size=target_vocab_size,
        draft_vocab_size=draft_vocab_size,
        vocab_map_version=vocab_map_version,
    )

    workers = []
    for idx in range(num_rollout_workers):
        workers.append(
            RolloutWorker(
                controller,
                store,
                adapter,
                capture,
                run_id=run_id,
                worker_id=f"rollout-{idx}",
            )
        )

    # --- loader (queue mode; no per-sample transform is configured) ---------
    def _cat_collate(feats):
        # Default collate: join every key of the first sample along dim 0.
        # Suitable for equal-length / batch_size=1 batches; callers needing
        # padded variable-length batching should supply ``collate_fn``.
        batch = {}
        for key in feats[0]:
            batch[key] = torch.cat([sample[key] for sample in feats], dim=0)
        return batch

    loader = FeatureDataLoader(
        store,
        controller.sample_queue,
        batch_size=batch_size,
        collate_fn=collate_fn if collate_fn else _cat_collate,
        drop_last=True,
        strategy="eagle3",
    )

    # --- training side: parallel config, FSDP backend, strategy, trainer ----
    parallel = ParallelConfig.from_distributed(
        tp_size=tp_size,
        sp_ulysses_size=sp_ulysses_size,
        sp_ring_size=sp_ring_size,
    )
    backend = FSDPTrainingBackend(parallel, optimizer_factory=optimizer_factory)
    wrapped = backend.prepare_model(
        eagle3_model, optimizer_target=eagle3_model.draft_model
    )
    strategy = Eagle3TrainStrategy(wrapped, target_head=None)
    core = TrainerCore(strategy, backend, accumulation_steps=accumulation_steps)
    trainer_id = controller.register_trainer({"role": "trainer", "run_id": run_id})

    def _ack(ids, step):
        # Report trained refs back to the controller for this trainer.
        return controller.ack_train_refs(
            trainer_id, ids, global_step=step, optimizer_durable=True
        )

    trainer = TrainerController(
        core,
        run_id=run_id,
        output_dir=output_dir,
        num_epochs=num_epochs,
        max_steps=max_steps,
        save_interval=save_interval,
        eval_interval=eval_interval,
        logger=logger,
        ack_fn=_ack,
    )

    def drive_rollout(max_rounds: int = 100_000) -> int:
        """Run the workers until the prompt pool drains; returns refs produced."""
        for worker in workers:
            worker.start()
        lease = max(batch_size * 8, 8)
        total = 0
        for _ in range(max_rounds):
            round_total = 0
            for worker in workers:
                round_total += len(worker.run_once(max_tasks=lease))
            if not round_total:
                # A full pass over all workers yielded nothing: pool is drained.
                break
            total += round_total
        return total

    return trainer, loader, workers, controller, drive_rollout


# Backward-compatible alias for early branch users.
build_offline_eagle3_controller = build_offline_eagle3_runtime


# Single authoritative export list (a second, immediately-overwritten
# assignment of ``__all__`` would be dead code).
__all__ = [
    "build_offline_eagle3_controller",
    "build_offline_eagle3_runtime",
    "build_online_eagle3_runtime",
]
Loading