Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions examples/disagg/run_disagg_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
manifest. No model, no trainer.
* **consumer** (training pool): waits for the manifest, builds the EAGLE3 model
exactly as the offline launcher does, and trains via
``build_disagg_eagle3_runtime`` reading features from the shared store.
``build_disagg_offline_runtime`` reading features from the shared store.

The control plane carries only ``SampleRef`` metadata across the boundary; the
feature tensors travel through the shared store. Disaggregation changes *where*
Expand Down Expand Up @@ -70,10 +70,7 @@
from specforge.runtime.data_plane.disaggregated import AuthPolicy, SharedDirFeatureStore
from specforge.runtime.data_plane.feature_store import FeatureStore
from specforge.runtime.data_plane.mooncake_store import MooncakeFeatureStore
from specforge.runtime.launch import (
build_disagg_eagle3_runtime,
build_offline_eagle3_runtime,
)
from specforge.runtime.launch import build_disagg_offline_runtime, build_offline_runtime

RUN_ID = "eagle3-disagg"

Expand Down Expand Up @@ -227,7 +224,8 @@ def run_colocated(args) -> None:
sanity_check(args)
eagle3_model, target_head, optimizer_factory = _build_model_and_optimizer(args)
print(f"[colocated] training from {args.train_hidden_states_path}", flush=True)
trainer, loader = build_offline_eagle3_runtime(
trainer, loader = build_offline_runtime(
strategy="eagle3",
hidden_states_path=args.train_hidden_states_path,
eagle3_model=eagle3_model,
target_head=target_head,
Expand Down Expand Up @@ -277,7 +275,8 @@ def run_consumer(args) -> None:
location = getattr(store, "root", f"mooncake://{store.store_id}")
print(f"[consumer] training from {len(refs)} disagg refs in {location}", flush=True)

trainer, loader = build_disagg_eagle3_runtime(
trainer, loader = build_disagg_offline_runtime(
strategy="eagle3",
feature_store=store,
refs=refs,
eagle3_model=eagle3_model,
Expand Down
52 changes: 27 additions & 25 deletions scripts/train_eagle3_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def optimizer_factory(draft_module):
logger = lambda m, s: print(f"step {s}: {m}", flush=True)

if online:
from specforge.runtime.launch import build_online_eagle3_runtime
# `strategy=` selects the draft model (here eagle3); the runtime resolves
# its StrategySpec. The builder function determines the topology, while the
# draft model is just a parameter passed to it.
from specforge.runtime.launch import build_online_runtime

# Online target produces features in-loop (any backend exposing
# generate_eagle3_data — HF or SGLang). is_online=True returns the model.
Expand All @@ -143,37 +145,37 @@ def optimizer_factory(draft_module):

# num_epochs=1: the rollout output is a consume-once stream. Multi-epoch
# online training (re-running the rollout each epoch) is left as a follow-up;
# for now a single rollout pass is performed.
trainer, loader, workers, controller, drive_rollout = (
build_online_eagle3_runtime(
target_model=target_model,
prompts=prompts,
eagle3_model=eagle3_model,
optimizer_factory=optimizer_factory,
run_id="eagle3-online",
output_dir=args.output_dir,
target_hidden_size=hidden_size,
target_vocab_size=vocab_size,
target_repr="logits",
ttt_length=args.ttt_length,
batch_size=args.target_batch_size,
accumulation_steps=args.draft_accumulation_steps,
num_epochs=1,
max_steps=args.max_num_steps,
save_interval=args.save_interval,
tp_size=args.tp_size,
sp_ulysses_size=args.sp_ulysses_size,
sp_ring_size=args.sp_ring_size,
logger=logger,
)
trainer, loader, workers, controller, drive_rollout = build_online_runtime(
strategy="eagle3",
target_model=target_model,
prompts=prompts,
eagle3_model=eagle3_model,
optimizer_factory=optimizer_factory,
run_id="eagle3-online",
output_dir=args.output_dir,
target_hidden_size=hidden_size,
target_vocab_size=vocab_size,
target_repr="logits",
ttt_length=args.ttt_length,
batch_size=args.target_batch_size,
accumulation_steps=args.draft_accumulation_steps,
num_epochs=1,
max_steps=args.max_num_steps,
save_interval=args.save_interval,
tp_size=args.tp_size,
sp_ulysses_size=args.sp_ulysses_size,
sp_ring_size=args.sp_ring_size,
logger=logger,
)
produced = drive_rollout()
print(f"[online] rollout produced {produced} samples", flush=True)
trainer.fit(loader)
else:
from specforge.runtime.launch import build_offline_eagle3_runtime
from specforge.runtime.launch import build_offline_runtime

target_head, _ = build_target_model(args, draft_config, is_online=False)
trainer, loader = build_offline_eagle3_runtime(
trainer, loader = build_offline_runtime(
strategy="eagle3",
hidden_states_path=args.train_hidden_states_path,
eagle3_model=eagle3_model,
target_head=target_head,
Expand Down
Loading
Loading