Skip to content
72 changes: 72 additions & 0 deletions examples/disagg/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Disaggregated offline EAGLE3 example

Runs the offline EAGLE3 training of `scripts/train_eagle3_dataflow.py`, but splits
it across **two pools that share only a filesystem mount** — the M6 disaggregation
seam (`SharedDirFeatureStore`). It is the runnable proof that *disaggregation
changes where features live, not their values*: the training curve matches the
colocated offline run.

## How it works

```
producer pool (node 0) shared mount training pool (node 1)
───────────────────── ────────────── ──────────────────────
ingest_offline_features() ──put()──▶ SharedDirFeatureStore ──get()──▶ FeatureDataLoader
write_ref_manifest() ──json──▶ refs.json (no tensors) ──read──▶ build_disagg_eagle3_runtime
TrainerController.fit()
```

The control plane carries only tensor-free `SampleRef` metadata (the manifest);
feature tensors travel through the shared store. `build_disagg_eagle3_runtime`
reuses the exact offline trainer assembly, so results align by construction.

## Run it (rcli, 2 nodes)

1. Generate offline features on node 0 (any EAGLE3 feature generator), e.g. into
`/root/disagg/features` as `*.ckpt` with keys
`input_ids,loss_mask,hidden_state,aux_hidden_state`.
2. Drive both pools at once — node 0 ingests, node 1 trains:

```bash
rcli exec --per-node <job> 'bash examples/disagg/run_qwen2.5_7b_eagle3_disagg.sh'
```

The wrapper branches on `RCLI_NODE_RANK`. Override paths/steps via env
(`DISAGG_STORE_ROOT`, `FEATURES_DIR`, `MAX_STEPS`, `NPROC`, …). Both pools must
share `DISAGG_STORE_ROOT`/`DISAGG_STORE_ID` and (if set) `DISAGG_AUTH_TOKEN`
(B9 auth).

## Single-host smoke

`DISAGG_ROLE` overrides the rank-derived role, so you can run both halves on one
host sharing a local dir — run the producer once, then the consumer:

```bash
DISAGG_ROLE=producer python examples/disagg/run_disagg_eagle3.py <args>
DISAGG_ROLE=consumer torchrun --standalone --nproc_per_node 1 \
examples/disagg/run_disagg_eagle3.py <args>
```

The bit-exact equivalence to the colocated path is covered by
`tests/test_runtime/test_disagg_launch.py`.

## Head-to-head vs colocated (Qwen2.5-7B, 2-node H200)

`DISAGG_ROLE=colocated` runs the same model build + assembly through
`build_offline_eagle3_runtime` (`LocalFeatureStore`). On identical features/seed,
the disaggregated consumer and the colocated baseline produce the same training
metrics to ~5 significant figures (residual ~1e-6–1e-8 is GPU run-to-run
floating-point noise, not the transport — feature tensors are byte-identical):

| step | metric | disagg | colocated |
|---|---|---|---|
| 20 | acceptance_rate | 0.0013300 | 0.0013300 |
| 20 | ploss | 5.386736 | 5.386740 |
| 20 | acc | 0.0272590 | 0.0272590 |
| 120 | acceptance_rate | 0.0223610 | 0.0223505 |
| 180 | acceptance_rate | 0.0337013 | 0.0336982 |

acc / acceptance_rate climb over training in both (baseline direction). Per-step
values are noisy at `batch_size=1` over 64 diverse samples. Note this is the
training-time acceptance proxy; the serving accept-length (τ via spec-decoding) is
a separate eval gate.
248 changes: 248 additions & 0 deletions examples/disagg/run_disagg_eagle3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
"""Disaggregated offline EAGLE3 example: producer and consumer on different pools.

This is the *assemble* example for the M6 disaggregation seam. It runs the SAME
offline EAGLE3 training as ``scripts/train_eagle3_dataflow.py`` (reusing its model
builders, so results align with the colocated run), but splits the work across two
pools that share only a filesystem mount:

* **producer** (rollout/feature pool): ``ingest_offline_features`` loads the
precomputed ``.ckpt`` features and ``put()``s them into a
``SharedDirFeatureStore`` on the shared mount, then publishes a tensor-free ref
manifest. No model, no trainer.
* **consumer** (training pool): waits for the manifest, builds the EAGLE3 model
exactly as the offline launcher does, and trains via
``build_disagg_eagle3_runtime`` reading features from the shared store.

The control plane carries only ``SampleRef`` metadata across the boundary; the
feature tensors travel through the shared store. Disaggregation changes *where*
features live, not their values, so the training curve matches the colocated
offline baseline.

Role is taken from ``DISAGG_ROLE`` (``producer``/``consumer``); if unset it is
derived from ``RCLI_NODE_RANK`` (0 -> producer, else consumer). Shared paths +
auth come from the environment so one wrapper can drive both nodes:

DISAGG_STORE_ROOT=/workspace/disagg_store # shared mount, both pools
DISAGG_MANIFEST=/workspace/disagg_store/refs.json
DISAGG_STORE_ID=eagle3-disagg # producer/consumer must match
DISAGG_AUTH_TOKEN=<secret> # optional (B9 auth)
"""

import os
import time

from accelerate.utils import set_seed

# reuse the existing builders so model construction matches the offline path
from train_eagle3 import (
build_dataloaders,
build_draft_model,
build_target_model,
parse_args,
sanity_check,
)

from specforge.distributed import destroy_distributed, init_distributed
from specforge.optimizer import BF16Optimizer
from specforge.runtime.data_plane.disagg_ingest import (
ingest_offline_features,
read_ref_manifest,
write_ref_manifest,
)
from specforge.runtime.data_plane.disaggregated import AuthPolicy, SharedDirFeatureStore
from specforge.runtime.launch import (
build_disagg_eagle3_runtime,
build_offline_eagle3_runtime,
)

RUN_ID = "eagle3-disagg"


def _role() -> str:
role = os.environ.get("DISAGG_ROLE")
if role:
return role
return "producer" if os.environ.get("RCLI_NODE_RANK", "0") == "0" else "consumer"


def _store(args, *, retain_on_release: bool = False) -> SharedDirFeatureStore:
token = os.environ.get("DISAGG_AUTH_TOKEN") or None
return SharedDirFeatureStore(
os.environ["DISAGG_STORE_ROOT"],
store_id=os.environ.get("DISAGG_STORE_ID", RUN_ID),
auth=AuthPolicy(token),
credential=token,
retain_on_release=retain_on_release,
)


def run_producer(args) -> None:
manifest = os.environ["DISAGG_MANIFEST"]
store = _store(args)
refs = ingest_offline_features(
store,
args.train_hidden_states_path,
run_id=RUN_ID,
ttt_length=args.ttt_length,
max_len=args.max_length,
)
write_ref_manifest(refs, manifest)
open(manifest + ".done", "w").close() # liveness marker the consumer waits on
print(
f"[producer] ingested {len(refs)} samples into {store.root}; "
f"manifest -> {manifest}",
flush=True,
)


def _build_model_and_optimizer(args):
"""Identical EAGLE3 model/optimizer build for both consumer and colocated.

Sharing this keeps the disaggregated and colocated runs apples-to-apples
(same draft, same target_head, same optimizer) so any metric difference can
only come from the feature transport, not the model.
"""
draft_config, draft_model, _ckpt, _resume = build_draft_model(args)
target_head, _ = build_target_model(args, draft_config, is_online=False)
_train, vocab_mapping_path, _eval = build_dataloaders(args, draft_config)
draft_model.load_vocab_mapping(vocab_mapping_path)

from specforge import OnlineEagle3Model

eagle3_model = OnlineEagle3Model(
draft_model=draft_model,
length=args.ttt_length,
attention_backend=args.attention_backend,
lk_loss_type=args.lk_loss_type,
kl_scale=args.kl_scale,
kl_decay=args.kl_decay,
).cuda()

def optimizer_factory(draft_module):
return BF16Optimizer(
draft_module,
lr=args.learning_rate,
max_grad_norm=args.max_grad_norm,
warmup_ratio=args.warmup_ratio,
total_steps=args.total_steps or 10_000,
)

return eagle3_model, target_head, optimizer_factory


def _log_interval() -> int:
return int(os.environ.get("DISAGG_LOG_INTERVAL", "25"))


def run_colocated(args) -> None:
"""Baseline: same model + assembly via build_offline (LocalFeatureStore).

For a head-to-head accept-length/loss comparison against the disaggregated
consumer on identical features/seed.
"""
init_distributed(
timeout=args.dist_timeout,
tp_size=args.tp_size,
sp_ring_size=args.sp_ring_size,
sp_ulysses_size=args.sp_ulysses_size,
)
sanity_check(args)
eagle3_model, target_head, optimizer_factory = _build_model_and_optimizer(args)
print(f"[colocated] training from {args.train_hidden_states_path}", flush=True)
trainer, loader = build_offline_eagle3_runtime(
hidden_states_path=args.train_hidden_states_path,
eagle3_model=eagle3_model,
target_head=target_head,
optimizer_factory=optimizer_factory,
run_id="eagle3-colocated",
output_dir=args.output_dir,
ttt_length=args.ttt_length,
max_len=args.max_length,
batch_size=args.target_batch_size,
accumulation_steps=args.draft_accumulation_steps,
num_epochs=args.num_epochs,
max_steps=args.max_num_steps,
save_interval=args.save_interval,
tp_size=args.tp_size,
sp_ulysses_size=args.sp_ulysses_size,
sp_ring_size=args.sp_ring_size,
logger=lambda m, s: print(f"step {s}: {m}", flush=True),
log_interval=_log_interval(),
)
trainer.fit(loader)
destroy_distributed()


def run_consumer(args) -> None:
manifest = os.environ["DISAGG_MANIFEST"]
init_distributed(
timeout=args.dist_timeout,
tp_size=args.tp_size,
sp_ring_size=args.sp_ring_size,
sp_ulysses_size=args.sp_ulysses_size,
)
sanity_check(
args
) # derives target_batch_size/dp_size the builders read (needs dist)
# wait for the producer to publish the manifest (shared mount)
deadline = time.monotonic() + 1800
while not os.path.exists(manifest + ".done"):
if time.monotonic() > deadline:
raise SystemExit(f"[consumer] timed out waiting for {manifest}.done")
time.sleep(2)

eagle3_model, target_head, optimizer_factory = _build_model_and_optimizer(args)

# offline ref set is re-iterated across epochs -> retain on release (read-only)
store = _store(args, retain_on_release=True)
refs = read_ref_manifest(manifest)
print(
f"[consumer] training from {len(refs)} disagg refs in {store.root}", flush=True
)

trainer, loader = build_disagg_eagle3_runtime(
feature_store=store,
refs=refs,
eagle3_model=eagle3_model,
target_head=target_head,
optimizer_factory=optimizer_factory,
run_id=RUN_ID,
output_dir=args.output_dir,
max_len=args.max_length,
batch_size=args.target_batch_size,
accumulation_steps=args.draft_accumulation_steps,
num_epochs=args.num_epochs,
max_steps=args.max_num_steps,
save_interval=args.save_interval,
tp_size=args.tp_size,
sp_ulysses_size=args.sp_ulysses_size,
sp_ring_size=args.sp_ring_size,
logger=lambda m, s: print(f"step {s}: {m}", flush=True),
log_interval=_log_interval(),
)
trainer.fit(loader)
destroy_distributed()


def main() -> None:
parser, args = parse_args()
set_seed(args.seed)
if args.train_hidden_states_path is None:
raise SystemExit(
"disagg example wires the OFFLINE path; pass --train-hidden-states-path"
)
role = _role()
print(
f"[disagg] role={role} node_rank={os.environ.get('RCLI_NODE_RANK', '0')}",
flush=True,
)
if role == "producer":
run_producer(args)
elif role == "colocated":
run_colocated(args) # baseline for head-to-head comparison
else:
run_consumer(args)


if __name__ == "__main__":
main()
62 changes: 62 additions & 0 deletions examples/disagg/run_qwen2.5_7b_eagle3_disagg.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/usr/bin/env bash
# Disaggregated offline EAGLE3 (Qwen2.5-7B) across two pools sharing a mount.
#
# Drive both nodes at once with rcli (node 0 ingests features into the shared
# store; node 1 trains the draft from it):
#
# rcli exec --per-node <job> 'bash examples/disagg/run_qwen2.5_7b_eagle3_disagg.sh'
#
# Prereq: features already generated on node 0 at $FEATURES_DIR (see README).
set -euo pipefail

: "${SF_HOME:=/root/SpecForge}"
: "${TARGET_MODEL:=Qwen/Qwen2.5-7B-Instruct}"
: "${DRAFT_CONFIG:=$SF_HOME/configs/qwen2.5-7b-eagle3.json}"
: "${PROMPTS:=/root/disagg/prompts.jsonl}"
: "${FEATURES_DIR:=/root/disagg/features}"
: "${OUTPUT_DIR:=/root/disagg/out}"
: "${MAX_STEPS:=200}"
: "${NUM_EPOCHS:=10}"
: "${TTT_LENGTH:=7}"
: "${NPROC:=1}"
: "${CHAT_TEMPLATE:=qwen}"
: "${LEARNING_RATE:=5e-5}"
: "${CACHE_DIR:=$SF_HOME/cache}"
: "${DISAGG_STORE_ROOT:=/workspace/disagg_store}"
: "${DISAGG_MANIFEST:=/workspace/disagg_store/refs.json}"

# shared store + auth (both pools), and image/cache env
export DISAGG_STORE_ROOT DISAGG_MANIFEST
export DISAGG_STORE_ID="${DISAGG_STORE_ID:-eagle3-disagg}"
export DISAGG_AUTH_TOKEN="${DISAGG_AUTH_TOKEN:-disagg-secret}"
export FLASHINFER_DISABLE_VERSION_CHECK=1
export HOME=/root HF_HOME=/root/.cache/huggingface TRITON_CACHE_DIR=/root/.triton
export PYTHONPATH="$SF_HOME:$SF_HOME/scripts:${PYTHONPATH:-}"
cd "$SF_HOME"

COMMON=(
--target-model-path "$TARGET_MODEL"
--target-model-backend hf
--draft-model-config "$DRAFT_CONFIG"
--train-data-path "$PROMPTS"
--train-hidden-states-path "$FEATURES_DIR"
--output-dir "$OUTPUT_DIR"
--chat-template "$CHAT_TEMPLATE"
--cache-dir "$CACHE_DIR"
--attention-backend flex_attention
--ttt-length "$TTT_LENGTH"
--max-num-steps "$MAX_STEPS"
--num-epochs "$NUM_EPOCHS"
--batch-size 1
--learning-rate "$LEARNING_RATE"
--seed 0
)

if [ "${RCLI_NODE_RANK:-0}" = "0" ]; then
echo "[node0] PRODUCER: ingest $FEATURES_DIR -> $DISAGG_STORE_ROOT"
DISAGG_ROLE=producer python examples/disagg/run_disagg_eagle3.py "${COMMON[@]}"
else
echo "[node1] CONSUMER: train from shared store ($NPROC gpu)"
DISAGG_ROLE=consumer torchrun --standalone --nproc_per_node "$NPROC" \
examples/disagg/run_disagg_eagle3.py "${COMMON[@]}"
fi
Loading
Loading