Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions scripts/train_eagle3.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# NOTE: core EAGLE3 training is being migrated to the DataFlow runtime launcher
# scripts/train_eagle3_dataflow.py (offline + online; validated old-vs-new on 7B).
# That launcher does not YET cover the following, so this script remains the path
# for them: VLM (--is-vlm), USP sequence parallelism (--attention-backend usp),
# the eval loop (--eval-*-path), --resume, and experiment trackers (--report-to).
import argparse
import hashlib
import math
Expand Down
116 changes: 116 additions & 0 deletions scripts/train_eagle3_dataflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Thin launcher: offline EAGLE3 training through the SpecForge DataFlow runtime.

This script is a *launcher* (M3): it builds models + optimizer, hands them to the
runtime, and runs ``TrainerController.fit``. No training logic lives here — the
loop, loss, projection, checkpoint, and eval all live in ``specforge.runtime``.

Reuses the existing model/data builders from ``scripts.train_eagle3`` so model
construction stays DRY; only the *orchestration* moves behind the runtime.

Example (offline):
torchrun --standalone --nproc_per_node 1 scripts/train_eagle3_dataflow.py \
--target-model-path <hf-model> --draft-model-config configs/llama3-8B-eagle3.json \
--train-data-path <prompts.jsonl> --train-hidden-states-path <features_dir> \
--output-dir ./output --max-num-steps 20
"""

from accelerate.utils import set_seed

# reuse existing builders so model construction is not duplicated
from train_eagle3 import (
build_dataloaders,
build_draft_model,
build_target_model,
parse_args,
)

from specforge.distributed import destroy_distributed, init_distributed
from specforge.optimizer import BF16Optimizer
from specforge.runtime.launch import build_offline_eagle3_runtime


def main():
parser, args = parse_args()
# parse_args() does not derive target_batch_size (train_eagle3.main computes
# it inline before building dataloaders); the offline runtime builder and
# build_dataloaders both read it, so derive it here too.
args.target_batch_size = args.tp_size * args.batch_size

# TODO(dataflow-launcher parity with scripts/train_eagle3.py): this launcher
# covers core EAGLE3 training (offline + online: loss / projection / FSDP /
# TP / grad-accum / checkpoint), validated old-vs-new. The following
# train_eagle3.py features are NOT yet wired here and still require the
# legacy script:
# - VLM / multimodal targets (--is-vlm, QwenVLOnlineEagle3Model)
# - USP sequence parallelism (--attention-backend usp -> process_data_usp;
# this path uses OfflineEagle3Dataset.process_data, no per-rank seq shard)
# - eval loop (--eval-data-path / --eval-hidden-states-path)
# - resume from checkpoint (--resume)
# - experiment trackers (--report-to wandb / swanlab / tensorboard)
# - online multi-epoch re-rollout (online runs a single consume-once pass)
set_seed(args.seed)
init_distributed(
timeout=args.dist_timeout,
tp_size=args.tp_size,
sp_ring_size=args.sp_ring_size,
sp_ulysses_size=args.sp_ulysses_size,
)
if args.train_hidden_states_path is None:
raise SystemExit(
"train_eagle3_dataflow currently wires the OFFLINE path; pass "
"--train-hidden-states-path. (Online wiring composes RolloutWorker + "
"SGLangAdapter over the same control/data plane.)"
)

draft_config, draft_model, _ckpt, _resume = build_draft_model(args)
target_head, _ = build_target_model(args, draft_config, is_online=False)
# vocab mapping is produced from the prompt dataset exactly as today
_train, vocab_mapping_path, _eval = build_dataloaders(args, draft_config)
draft_model.load_vocab_mapping(vocab_mapping_path)

from specforge import OnlineEagle3Model

eagle3_model = OnlineEagle3Model(
draft_model=draft_model,
length=args.ttt_length,
attention_backend=args.attention_backend,
lk_loss_type=args.lk_loss_type,
kl_scale=args.kl_scale,
kl_decay=args.kl_decay,
).cuda()

# built AFTER FSDP-wrap (inside the runtime) over the wrapped inner draft
def optimizer_factory(draft_module):
return BF16Optimizer(
draft_module,
lr=args.learning_rate,
max_grad_norm=args.max_grad_norm,
warmup_ratio=args.warmup_ratio,
total_steps=args.total_steps or 10_000,
)

trainer, loader = build_offline_eagle3_runtime(
hidden_states_path=args.train_hidden_states_path,
eagle3_model=eagle3_model,
target_head=target_head,
optimizer_factory=optimizer_factory,
run_id="eagle3-offline",
output_dir=args.output_dir,
ttt_length=args.ttt_length,
max_len=args.max_length,
batch_size=args.target_batch_size,
accumulation_steps=args.draft_accumulation_steps,
num_epochs=args.num_epochs,
max_steps=args.max_num_steps,
save_interval=args.save_interval,
tp_size=args.tp_size,
sp_ulysses_size=args.sp_ulysses_size,
sp_ring_size=args.sp_ring_size,
logger=lambda m, s: print(f"step {s}: {m}", flush=True),
)
trainer.fit(loader)
destroy_distributed()


if __name__ == "__main__":
main()
28 changes: 21 additions & 7 deletions specforge/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,27 @@ def init_distributed(

def destroy_distributed():
global _TP_GROUP, _DP_GROUP, _SP_ULYSSES_GROUP, _SP_RING_GROUP, _DRAFT_DP_GROUP
dist.destroy_process_group(_TP_GROUP)
dist.destroy_process_group(_DP_GROUP)
dist.destroy_process_group(_SP_ULYSSES_GROUP)
dist.destroy_process_group(_SP_RING_GROUP)
dist.destroy_process_group(_DRAFT_DP_GROUP)
dist.destroy_process_group(_DRAFT_SP_GROUP)
dist.destroy_process_group()
# Teardown must never crash the process: a group may be None (e.g. a trivial
# single-rank world) or already destroyed. Destroy each defensively so a
# successful run does not exit non-zero on cleanup.
for group in (
_TP_GROUP,
_DP_GROUP,
_SP_ULYSSES_GROUP,
_SP_RING_GROUP,
_DRAFT_DP_GROUP,
_DRAFT_SP_GROUP,
):
if group is None:
continue
try:
dist.destroy_process_group(group)
except Exception:
pass
try:
dist.destroy_process_group()
except Exception:
pass


def shard_tensor(
Expand Down
156 changes: 156 additions & 0 deletions specforge/runtime/ARCHITECTURE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# SpecForge DataFlow Runtime — Architecture (M1–M4)

The runtime moves SpecForge from a trainer-centered god-script to explicit
contracts across **two planes**. This is the cross-plane map; each plane also has
its own design note (see "Per-plane internals" below).

## Plane responsibilities

- **Contracts** — the stdlib-only data records every plane exchanges (`PromptTask`, `SampleRef`, `FeatureSpec`, `FeatureHandle`, `TrainBatch`) plus `assert_no_tensors`, which enforces the no-tensor boundary. Control-plane records carry metadata only; `TrainBatch` is the sole tensor carrier and lives only on the trainer side.
- **Control plane** — `DataFlowController` (a passive coordinator with no run loop), `MetadataStore` (commit dedup + the single durable ack transaction), `SampleRefQueue` (lease/ack/fail transport), and `TrainLease`. Moves metadata only; every record-accepting entrypoint runs `assert_no_tensors`.
- **Data plane** — `FeatureStore`/`LocalFeatureStore` is the only holder of tensors, addressed by metadata-only `SampleRef`. `FeatureDataLoader` is the bridge that materializes refs + store into collated `TrainBatch`es; `OfflineManifestReader` turns precomputed `.ckpt` files into in-place `file://` refs.
- **Inference (compute)** — `RolloutWorker` + `SGLangAdapter` extract features from the target engine and commit only `SampleRef` metadata; `SGLangAdapter._project_target` is the sole target to draft projection site and `verify_capture` is the loud pre-`put` validator.
- **Training (compute)** — `TrainerController` -> `TrainerCore` -> `DraftTrainStrategy` + `FSDPTrainingBackend` turn `TrainBatch`es into optimizer steps and checkpoints; the strategy owns projection/loss, the core is branch-free.

## End-to-end flow

**Online:** `RolloutWorker.run_once` leases prompts (`lease_prompt_tasks`), calls `generate_features` (which drives `generate_eagle3_data`), runs `verify_capture`, then `FeatureStore.put` writes tensors **directly to the data plane**. Only the resulting `SampleRef` metadata goes to the controller via `commit_samples`, which dedups through `MetadataStore.commit_sample` and enqueues fresh refs onto `SampleRefQueue`.

**Offline:** `OfflineManifestReader.read()` emits in-place `file://` `SampleRef`s (no tensor copy) and the launcher calls `enqueue_offline_refs`, which dedups and enqueues onto the **same** `SampleRefQueue`.

**Convergence + training:** Both paths converge at `SampleRef` on one queue, so the trainer has no online/offline branch. `TrainerController.fit` drives `for batch in loader`; the `FeatureDataLoader` leases refs through a `TrainLease` (`lease_train_refs` via the controller) and fetches the actual tensors **directly** from the `FeatureStore` (`get`/`release` with clone-on-fetch). `TrainerCore.train_step` runs forward/loss/backward and steps the optimizer at the grad-accum boundary; at that boundary `ack_fn` calls `ack_train_refs`, which records the durable ack transaction (`record_train_ack`) **before** releasing the queue lease.

## System map

Solid edges = control / metadata flow. Dashed edges = tensor flow (data plane).
The `DataFlowController` is **passive**: callers point *into* it, and its only
outbound edges go to its own `SampleRefQueue` and `MetadataStore`. Tensors never
cross the controller.

```mermaid
flowchart TD
classDef control fill:#e8f0fe,stroke:#3b6fd6,color:#0b2e6b;
classDef data fill:#fdeede,stroke:#d6893b,color:#6b3a0b;
classDef compute fill:#e6f6ea,stroke:#3bb061,color:#0b4a22;

subgraph COMPUTE[compute autonomous loops]
RW[RolloutWorker run_once loop]
SG[SGLangAdapter generate_features]
TGT[Eagle3TargetModel generate_eagle3_data]
TR[TrainerController fit loop]
CORE[TrainerCore train_step]
STRAT[Eagle3TrainStrategy forward_loss]
BE[FSDPTrainingBackend backward step]
LOADER[FeatureDataLoader iter]
OFF[OfflineManifestReader read]
end

subgraph CONTROL[control plane metadata only]
CTRL[DataFlowController passive coordinator]
QUEUE[SampleRefQueue lease ack fail]
MS[MetadataStore commit dedup durable ack]
LEASE[TrainLease get ack fail]
end

subgraph DATA[data plane tensors only]
STORE[FeatureStore LocalFeatureStore]
end

RW -->|register_rollout_worker| CTRL
RW -->|lease_prompt_tasks| CTRL
RW -->|generate_features| SG
SG -->|generate_eagle3_data| TGT
RW -.->|put| STORE
RW -->|commit_samples| CTRL
RW -->|fail_prompt_tasks| CTRL

OFF -->|enqueue_offline_refs| CTRL

CTRL -->|commit_sample| MS
CTRL -->|record_train_ack| MS
CTRL -->|get_committed| MS
CTRL -->|put fresh refs| QUEUE
CTRL -->|get| QUEUE
CTRL -->|ack| QUEUE
CTRL -->|fail| QUEUE

TR -->|train_step| CORE
CORE -->|forward_loss| STRAT
CORE -->|backward step| BE
TR -->|for batch in loader| LOADER
LOADER -->|lease_train_refs get| LEASE
LEASE -->|lease_train_refs| CTRL
LEASE -->|ack_train_refs| CTRL
LEASE -->|fail_refs| CTRL
LOADER -.->|get release| STORE
TR -->|ack_fn ack_train_refs| CTRL

class RW,SG,TGT,TR,CORE,STRAT,BE,LOADER,OFF compute;
class CTRL,QUEUE,MS,LEASE control;
class STORE data;
```

## Endpoint reference

| Caller | Endpoint called | Plane | Purpose |
|---|---|---|---|
| RolloutWorker | register_rollout_worker | control | Register worker, obtain authoritative worker_id (no-tensor guard on info) |
| RolloutWorker | lease_prompt_tasks | control | Pop up to max_tasks pending PromptTasks, mark leased to worker_id |
| RolloutWorker | generate_features | compute | Ask the FeatureSource (SGLangAdapter) for one feature dict per task |
| SGLangAdapter | generate_eagle3_data | compute | Run the target engine's batched forward to extract hidden_states/target |
| RolloutWorker | put | data | Persist verified feature tensors directly to FeatureStore, get back a SampleRef |
| RolloutWorker | abort | data | Clean up a partial/failed write so no corrupt sample is left |
| RolloutWorker | commit_samples | control | Commit metadata-only SampleRefs; dedup + enqueue fresh refs |
| RolloutWorker | fail_prompt_tasks | control | Release failed prompt leases (retryable or terminal) so none are stranded |
| OfflineManifestReader | enqueue_offline_refs | control | Offline ingest: dedup + enqueue file:// SampleRefs onto the same queue |
| DataFlowController | commit_sample | control | Dedup committed samples (True=new, False=duplicate) |
| DataFlowController | record_train_ack | control | Persist the single durable ack transaction before releasing leases |
| DataFlowController | get_committed | control | Resolve acked/failed sample_ids back to full SampleRef objects |
| DataFlowController | put | control | Enqueue freshly committed refs onto the shared SampleRefQueue |
| DataFlowController | get | control | Serve train-side leases from the SampleRefQueue |
| DataFlowController | ack | control | Release queue leases after the durable ack transaction is recorded |
| DataFlowController | fail | control | Route train-side ref failures through the queue (retryable flag) |
| TrainLease | lease_train_refs | control | Loader's get(): lease train refs via the controller, not a raw queue |
| TrainLease | ack_train_refs | control | ack(): record durable transaction + release leases via the controller |
| TrainLease | fail_refs | control | fail(): route ref failures through the controller |
| FeatureDataLoader | lease_train_refs (via TrainLease.get) | control | Lease a batch of refs from the stream |
| FeatureDataLoader | get | data | Fetch a sample's tensors + lease FeatureHandle directly from the store |
| FeatureDataLoader | release | data | Release the lease immediately after clone-on-fetch so prefetch can't race |
| TrainerController | for batch in loader (__iter__) | compute | Drive the loader to yield collated TrainBatch objects |
| TrainerController | train_step | compute | Run each micro-batch; read optimizer_stepped boundary signal |
| TrainerController | eval_step | compute | Run eval batches and aggregate metrics |
| TrainerController | ack_fn -> ack_train_refs | control | Close the ack loop: ack consumed sample_ids at the optimizer-step boundary |
| TrainerCore | forward_loss | compute | Delegate model-specific forward + loss to the strategy |
| TrainerCore | backward | compute | Run backward on the accumulation-scaled loss each micro-step |
| TrainerCore | step | compute | Optimizer step + distributed grad-norm reduction at the accum boundary |

## Autonomy: loops + a passive coordinator

## Autonomous loops + one passive coordinator

This is **not** a master/orchestrator that calls into sub-parts, and it is **not** a set of fully independent processes. It is a small number of **autonomous producer/consumer loops** coordinated by **one passive shared component**, the `DataFlowController`.

- The **producer loop** is `RolloutWorker.run_once`, which runs on its own and *calls into* the controller (`lease_prompt_tasks`, `commit_samples`, `fail_prompt_tasks`). The controller never calls the worker.
- The **consumer loop** is `TrainerController.fit`, which drives `for batch in loader`; the loader *calls into* the controller through `TrainLease` (`lease_train_refs` / `ack_train_refs` / `fail_refs`). The controller never calls the trainer.
- The `DataFlowController` has **no run loop**. The only edges out of it go into its **own** `SampleRefQueue` and `MetadataStore`. Workers and the trainer point INTO the controller; that is what makes it passive.

Tensors reinforce this: they **never** flow through the controller. `RolloutWorker` calls `FeatureStore.put` directly and `FeatureDataLoader` calls `FeatureStore.get`/`release` directly. Only metadata-bearing `SampleRef`s cross the control plane.

## Why this makes disaggregation mechanical

Because the loops are autonomous and only the coordinator is shared, moving components across nodes is a **swap, not a rewrite**:

- **Durable backend swap:** all recovery-critical state sits behind the `MetadataStore` ABC (commit dedup + the atomic `record_train_ack` marker). A SQLite/Redis/DB backend is a new subclass injected into the controller — no controller rewrite, and `assert_no_tensors` keeps the seam importable without torch.
- **`TrainLease` indirection:** the trainer never holds a raw in-process queue. It routes every `get`/`ack`/`fail` through the controller, so a cross-node trainer is a drop-in substitution and the durable ack transaction is always recorded.
- **`partition_key` seam:** `SampleRefQueue.put`/`get` already accept a `partition_key` (currently accepted but ignored, single partition), reserving the per-DP-rank partitioning needed for a sharded/disaggregated queue without an API change.
- **Online/offline convergence:** because `commit_samples` and `enqueue_offline_refs` land on the same queue and the trainer path is branch-free, the same consumer loop serves a disaggregated rollout fleet or a static offline manifest unchanged.

## Per-plane internals

Each plane carries its own design note (landed with that plane's PR):

- `contracts.py` / `CONTRACTS.md` — shared metadata records + `assert_no_tensors`
- `data_plane/DESIGN.md` — storage, queue, loader, lifecycle
- `control_plane/DESIGN.md` — controller, metadata store, lease/durability
- `inference/DESIGN.md` — rollout worker, capture, sglang seam
- `training/DESIGN.md` — trainer core, strategy, FSDP backend
Loading
Loading