[DataFlow runtime] Phase D — training managers (no_sync, full resume, checkpoint/eval)#637

Open
maocheng23 wants to merge 6 commits into dataflow-up-16-zerocopy from dataflow-up-29-training-managers

@maocheng23 maocheng23 commented Jul 1, 2026

Phase D — training managers (no_sync, full resume, checkpoint/eval)

Implements Phase D of the reconciled plan (#630, docs/roadmap/domain-refactor.md §D): brings the training loop to production parity. Stacked on #636 (Phase C). Four features plus best-checkpoint tracking, each listed below with its full API surface and how it is implemented.


1. Gradient accumulation with no_sync()

Where: specforge/runtime/training/trainer.py (TrainerCore), specforge/runtime/training/backend.py (FSDPTrainingBackend).

| Function | What it does / how |
| --- | --- |
| TrainerCore.train_step(batch, ctx) | Scales loss by accumulation_steps, decides the optimizer boundary before backward (self._micro % accumulation_steps == 0) and passes it down, steps the optimizer only at the boundary. |
| TrainingBackend.backward(loss, *, is_boundary=True) (ABC) | New keyword on the seam contract. |
| FSDPTrainingBackend.backward | Wraps non-boundary micro-steps in self.module.no_sync(), so the FSDP gradient reduce-scatter fires once per optimizer step, not once per micro-step. accumulation_steps=1 / unwrapped modules are byte-identical to before. |
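The boundary logic in the table can be sketched without FSDP. This is a minimal illustration, not the PR's code: FakeShardedModule and run_micro_steps are hypothetical names, and the fake module only counts how often a gradient collective would fire.

```python
from contextlib import contextmanager, nullcontext


class FakeShardedModule:
    """Hypothetical stand-in for an FSDP-wrapped module; it only counts syncs."""

    def __init__(self):
        self.sync_count = 0
        self._defer = False

    @contextmanager
    def no_sync(self):
        # Inside this context, backward() accumulates locally and skips the collective.
        self._defer = True
        try:
            yield
        finally:
            self._defer = False

    def backward(self):
        if not self._defer:
            self.sync_count += 1  # the reduce-scatter would fire here


def run_micro_steps(module, num_micro_steps, accumulation_steps):
    """Run micro-steps and return the number of optimizer steps taken."""
    optimizer_steps = 0
    for micro in range(1, num_micro_steps + 1):
        # Decide the boundary BEFORE backward so the right context can be chosen.
        is_boundary = micro % accumulation_steps == 0
        ctx = nullcontext() if is_boundary else module.no_sync()
        with ctx:
            module.backward()
        if is_boundary:
            optimizer_steps += 1
    return optimizer_steps


module = FakeShardedModule()
steps = run_micro_steps(module, num_micro_steps=8, accumulation_steps=4)
```

With 8 micro-steps and accumulation_steps=4, the sketch takes two optimizer steps and counts two syncs, so each optimizer step defers accumulation_steps−1 reductions.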

Deliberately removed: TrainerCore.eval_step and the mode plumbing — evaluation now consumes raw forward_loss outputs (see §4) because _result scalarizes away the per-position count tensors that correct acceptance-length aggregation needs.

2. Full resume

Where: specforge/runtime/training/backend.py, specforge/optimizer.py, specforge/training/trainer.py (domain Trainer), specforge/runtime/data_plane/feature_dataloader.py.

| Function | What it does / how |
| --- | --- |
| FSDPTrainingBackend.state_dict() | Returns {model, optimizer, rng}. model is gathered under FullStateDictConfig(offload_to_cpu=True, rank0_only=True) (no per-rank GPU replica; non-rank0 gets {}). rng holds the CPU state plus the bound CUDA device's state only, so checkpoints are independent of CUDA_VISIBLE_DEVICES layout. |
| FSDPTrainingBackend.load_state_dict(state) | Restores model / optimizer / RNG; tolerates legacy list-valued CUDA RNG (picks the bound device's entry or fails fast with a clear message). |
| BF16Optimizer.state_dict/load_state_dict | Now persist the fp32 master params (rank-local, shape-checked on load) alongside AdamW moments and the bundled LR scheduler, so resume is numerically faithful instead of re-quantizing the master from bf16 weights. Missing fp32_params (older checkpoint) falls back to re-clone with one warning. |
| Trainer(resume_from=...) (domain) | Restore order: read_resume_state → draft weights before the FSDP wrap (so the fp32 master is rebuilt from them) → optimizer/scheduler/RNG after prepare_model → start_step/epoch/data position. Weight load is guarded: non-empty key overlap, zero unexpected keys, and saved-strategy-name match, else ValueError (a mismatched checkpoint can no longer silently "resume" from random init). Persists and validates dataset_size and accumulation_steps; drift fails fast. |
| FeatureDataLoader.seek(num_batches) | Repositions the ref stream without materializing skipped features; raises when the skip exceeds the available batches. Mid-epoch position is persisted in samples (batch-size independent) and converted back to batches with a divisibility fail-fast. |
| TrainerController.fit() | Early-returns when global_step >= max_steps before consuming any data (idempotent restart); the mid-epoch skip is driven by the live position counter so re-entering fit() never re-trains the prefix; the plain-iterable skip path validates the consumed count. |
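The sample-based position and the bounded seek described above can be sketched as follows. RefStream and samples_to_batches are hypothetical names for illustration and do not reproduce the real FeatureDataLoader API.

```python
def samples_to_batches(consumed_samples: int, batch_size: int) -> int:
    """Convert a persisted sample position back to batches, failing fast on drift."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if consumed_samples < 0:
        raise ValueError("consumed_samples must be non-negative")
    if consumed_samples % batch_size != 0:
        raise ValueError(
            f"position {consumed_samples} is not divisible by batch_size {batch_size}"
        )
    return consumed_samples // batch_size


class RefStream:
    """Hypothetical ref stream: holds cheap references and slices them lazily."""

    def __init__(self, refs, batch_size):
        self._refs = list(refs)
        self._batch_size = batch_size
        self._pos = 0  # current position, in batches

    def num_batches(self) -> int:
        return len(self._refs) // self._batch_size

    def seek(self, num_batches: int) -> None:
        # Only the cursor moves; skipped batches are never built.
        if num_batches > self.num_batches():
            raise ValueError("seek exceeds the available batches")
        self._pos = num_batches

    def __iter__(self):
        for b in range(self._pos, self.num_batches()):
            start = b * self._batch_size
            yield self._refs[start:start + self._batch_size]


stream = RefStream(range(12), batch_size=4)
stream.seek(samples_to_batches(8, 4))
remaining = list(stream)
```

Storing the position in samples lets a resume with a different batch size either convert cleanly or fail loudly, rather than silently landing mid-batch.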

3. CheckpointManager

Where: specforge/training/checkpoint.py (new module).

| Function | What it does / how |
| --- | --- |
| CheckpointManager(output_dir, run_id, *, max_checkpoints=0, best_metric='eval/simulated_acc_len', best_min_delta=0.0) | One manager per run; rehydrates the best record from disk on construction. |
| checkpoint_dir(step) | Layout: {output_dir}/{run_id}-step{N}/ containing training_state.pt (rank0 shared payload) + training_state_rank{r}.pt per rank. The per-rank files are required because FSDP use_orig_params shards AdamW moments; rank0-only optimizer state would corrupt every other rank on resume. |
| save(state, step, *, rank_state=None) | All-rank makedirs; atomic writes (tmp + os.replace) for every payload; timeline rewind: rank0 deletes any on-disk step ≥ the one being saved (fork/rollback semantics, logged) and clears a best record inside the deleted range; cross-rank failure propagation: an all_gather of per-rank error strings makes one rank's FS error raise on all ranks instead of stranding the group in a barrier; rotation never deletes best_step or the just-written step, and its rmtree failures warn instead of killing the run. |
| is_better(eval_metrics) | Rank0 computes score > best_score + best_min_delta and the verdict is broadcast inside the manager, so callers can't accidentally diverge rank decisions before the collective-bearing save. Empty metrics → False. |
| update_best(step, eval_metrics) | Rank0-only, atomic {run_id}.best_meta.json write + {run_id}-best repoint. Meta records run_id and metric; the loader ignores mismatched or corrupt meta with one warning. All pointer/meta artifacts are run_id-scoped so runs sharing an output_dir cannot clobber each other. |
| read_resume_state(path, *, map_location='cpu', require_full_state=True) | The single resume reader: merges the shared payload with this rank's file, returned untouched under 'backend' (no key-rename seams). Fails fast on world-size mismatch and on optimizer-less checkpoints (require_full_state=False opts into weights+counters-only). |
| load(step=None) / latest_dir() | latest is trusted only when it is a real symlink pointing at a complete checkpoint; a materialized (copied) latest is repaired on the next save; the fallback scans only complete step dirs, so a truncated dir from a mid-save preemption is invisible. Relative symlinks keep pointers valid after relocating output_dir. |
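Two of the mechanisms in the table, the tmp + os.replace write and the relative pointer, can be sketched with the standard library. This is an illustration under assumptions: atomic_write_json and point_relative are hypothetical helpers, JSON stands in for the .pt payloads, and a POSIX filesystem with symlink support is assumed.

```python
import json
import os
import tempfile


def atomic_write_json(path, payload):
    """Write to a temp file in the same directory, then rename over the target."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp-")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(payload, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)  # readers see the old file or the new one
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def point_relative(output_dir, name, ckpt_dir):
    """Repoint output_dir/name at ckpt_dir using a relative symlink target."""
    link = os.path.join(output_dir, name)
    target = os.path.relpath(ckpt_dir, start=output_dir)
    tmp_link = link + ".tmp"
    if os.path.lexists(tmp_link):
        os.remove(tmp_link)
    os.symlink(target, tmp_link)
    os.replace(tmp_link, link)  # swap the pointer in one rename


with tempfile.TemporaryDirectory() as out:
    ckpt = os.path.join(out, "run-step10")
    os.makedirs(ckpt)
    atomic_write_json(os.path.join(ckpt, "training_state.json"), {"step": 10})
    point_relative(out, "latest", ckpt)
    link_target = os.readlink(os.path.join(out, "latest"))
    with open(os.path.join(out, "latest", "training_state.json")) as f:
        restored = json.load(f)
```

Because the link target is stored relative to output_dir, moving the whole directory keeps latest resolvable.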

4. Evaluator

Where: specforge/eval/evaluator.py (new module).

| Function | What it does / how |
| --- | --- |
| Evaluator.run(model, batches, strategy, ...) | Aggregates per-position accept counts across the whole eval pass and across DP ranks before the geometric simulated_acc_len, so the metric is batch-size- and rank-invariant (never a mean of ratios). Emits eval/avg_loss, eval/avg_acc, eval/simulated_acc_len, plus aggregated legacy eval/acceptance_rate_i / eval/ploss_i when the strategy provides them. |
| Collective schedule | Decided globally (one SUM over a fixed float64 slot vector, a MAX for the per-position length, one padded SUM iff any rank has per-position data): a rank with an empty or scalar-only shard issues identical collectives, which is hang-proof by construction. Reductions run on the rank's bound CUDA device; accumulation stays on-device (float64 counts, exact past 2^24) with one host sync. |
| Zero-batch pass | Returns {}, never fabricated zeros, so an empty/drained eval set cannot poison best tracking. |
| Scalar strategies | DFlash/Domino forward_loss now emit accuracy_denom (the accuracy's own denominator); the Evaluator weights scalar accuracy by it, restoring batch-size invariance for real scalar shapes (loss-token fallback documented). |
| TrainerController.evaluate(data) | Feeds raw forward_loss outputs to the Evaluator under the same StepContext as training (Domino's schedule-dependent loss evaluates consistently); data=None joins the collectives with an empty shard. Eval entry itself is a rank0-broadcast decision, and eval metrics are logged and merged into last_metrics. |
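Why counts are aggregated before the geometric fold can be shown with a toy calculation. The fold used here (sum of cumulative products of per-position acceptance rates) is an assumed definition for illustration, and folded_acc_len and aggregate are hypothetical names; the Evaluator's exact formula is not shown in this PR text.

```python
from itertools import accumulate
from operator import mul


def folded_acc_len(correct, denom):
    """Assumed fold: sum over k of the product of the first k acceptance rates."""
    rates = [c / d if d > 0 else 0.0 for c, d in zip(correct, denom)]
    return sum(accumulate(rates, mul))


def aggregate(batches):
    """Sum per-position (correct, denom) counts across batches."""
    n = max(len(c) for c, _ in batches)
    total_c = [0.0] * n
    total_d = [0.0] * n
    for c, d in batches:
        for i, (ci, di) in enumerate(zip(c, d)):
            total_c[i] += ci
            total_d[i] += di
    return total_c, total_d


# Two batches over the same two positions, with mismatched denominators.
batches = [([8, 4], [10, 8]), ([1, 0], [2, 1])]
pooled = folded_acc_len(*aggregate(batches))

# The same data presented as a single batch gives the same pooled value.
rebatched = folded_acc_len(*aggregate([([9, 4], [12, 9])]))

# Folding each batch first and averaging depends on how the data was batched.
mean_of_ratios = sum(folded_acc_len(c, d) for c, d in batches) / len(batches)
```

The pooled value depends only on the total counts, while the mean of per-batch folds changes whenever the batch boundaries move.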

5. Best-checkpoint tracking

Any eval that beats the record (by more than best_min_delta) persists a checkpoint on demand — decoupled from save_interval cadence. The record survives restarts via meta rehydration, is rank-agreed by construction (§3 is_better), and rotation always protects it.
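The record-keeping rule can be sketched in a single process. BestTracker is a hypothetical class; treating the first non-empty eval as an improvement is an assumption, and the rank0 broadcast, the on-disk meta file, and the checkpoint save are all left out.

```python
from typing import Mapping, Optional


class BestTracker:
    """Hypothetical single-process stand-in for the manager's best record."""

    def __init__(self, metric: str = "eval/simulated_acc_len", min_delta: float = 0.0):
        self.metric = metric
        self.min_delta = min_delta
        self.best_score: Optional[float] = None
        self.best_step: Optional[int] = None

    def is_better(self, eval_metrics: Mapping[str, float]) -> bool:
        # Empty metrics (for example a zero-batch eval pass) never count as a win.
        if not eval_metrics or self.metric not in eval_metrics:
            return False
        score = float(eval_metrics[self.metric])
        # Assumption: with no record yet, the first score becomes the best.
        return self.best_score is None or score > self.best_score + self.min_delta

    def update_best(self, step: int, eval_metrics: Mapping[str, float]) -> None:
        self.best_score = float(eval_metrics[self.metric])
        self.best_step = step


m = "eval/simulated_acc_len"
tracker = BestTracker(min_delta=0.05)
history = []
evals = [(100, {m: 2.00}), (200, {m: 2.03}), (300, {}), (400, {m: 2.10})]
for step, metrics in evals:
    if tracker.is_better(metrics):
        # A real loop would persist a checkpoint here before recording the pointer.
        tracker.update_best(step, metrics)
        history.append(step)
```

In this run, step 200 improves by less than min_delta and step 300 has empty metrics, so only steps 100 and 400 update the record.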

Wiring

Every build_* entry point in specforge/runtime/launch.py forwards resume_from / max_checkpoints. Disagg-online builders fail fast when resume_from is given without resume=True, and when the durable ack marker is ahead of the checkpoint (acked-but-not-checkpointed samples would be silently skipped — the reconcile docstring now states the accurate guarantee). New managers live in their S-homes (training/, eval/); the runtime seam imports them lazily.

Tests / gates

  • test_no_sync_equiv.py — 2-rank FSDP, distinct per-rank data: the accumulation path matches per-micro-step reduction, and deferral counting pins exactly accumulation_steps−1 deferrals per optimizer step.
  • test_checkpoint_manager.py (new) — atomicity (truncated dirs invisible), timeline rewind, run-scoped meta/pointers, corrupt/mismatched meta, glob-metachar run_ids, materialized-latest repair, is_better/update_best semantics incl. best_min_delta, read_resume_state branches, and a 2-process gloo save→per-rank-read gate (dynamic port).
  • test_checkpoint_resume.py — bit-for-bit weight round-trip with an asserted missing-key set; resume continuity vs an uninterrupted reference on distinct batches with an explicit LR-scheduler position assert; the production Trainer(resume_from=...) entrypoint end-to-end; fail-fast gates (strategy mismatch, zero-overlap, dataset-size drift, accumulation-steps drift, world-size mismatch); max_steps idempotent restart; re-entered fit() prefix safety.
  • test_evaluator_aggregation.py — aggregate-before-geometric-sum, batch-size invariance under realistic mismatched denominators, zero-batch → {}, float64 counts, 2-proc gloo ragged/empty-shard schedule (scoped note: covers the Evaluator's own collectives).
  • test_feature_dataloader.py / test_trainer.py — seek semantics incl. bounds + non-materialization, boundary-flag sequence on CPU, eval metrics reach the logger, checkpoint payload extras.

Validation

  • Full CPU-safe tests/test_runtime suite: 241 passed, 24 skipped (CUDA/multi-GPU gates), 1 xfailed locally after the review-fix pass.
  • 8×H200 pod: 239 OK (2 skip, 1 xfail) on the pre-fix revision; the GPU-gated tests (no_sync equivalence, launch equivalence) and multi-rank checkpoint paths need a pod re-run on the current head.

Review history

Three adversarial review rounds are folded in:

  1. Review fixes 1–2 (see commit trail): per-rank resume state, data-stream seek, DP-safe evaluator schedule, durable rank-agreed best tracking, bound-device collectives.
  2. Review fixes 3 (latest): checkpoint lifecycle hardening — atomic writes, rollback/fork rewind semantics, cross-rank failure propagation, run-scoped best/latest artifacts, optimizer-less-resume fail-fast, fp32-master persistence, bound-device RNG, rank0-only gather, scalar-accuracy denominator weighting, zero-batch eval guard, eval observability, max_steps idempotency, embed-filter freeze check, online resume knob coupling — plus a comment-slimming pass across all touched modules and the test hardening listed above.

🤖 Generated with Claude Code


How to review — commit by commit

The branch is structured as 5 self-contained behavioral commits, each with its own tests and a green CPU suite at that point. Reviewing commit-by-commit is the intended path:

| # | Commit | Contract | Suite at commit |
| --- | --- | --- | --- |
| 1 | 6e1b739 | no_sync accumulation: boundary decided before backward; FSDP reduces once per optimizer step | 191 passed |
| 2 | eb5e263 | Backend full state + CheckpointManager: {model, optimizer(+fp32 masters), rng}; atomic, per-rank, run-scoped layout | 210 passed |
| 3 | e96441c | Full resume: domain restore ordering, seek-based data repositioning, drift fail-fasts | 223 passed |
| 4 | fa882d1 | Evaluator: aggregate-before-geometric-sum, DP-safe collective schedule, accuracy_denom weighting | 240 passed |
| 5 | f512828 | Best-checkpoint tracking in the training loop; docs | 241 passed |

Rebased onto dataflow-up-16-zerocopy after the B/C stack (#622 through #636) merged; adopts the build_control_plane_for_mode rename from the Phase C review fixes.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a batch-size invariant Evaluator for speculative decoding metrics, a CheckpointManager for handling checkpoint layout and rotation, and full training resume capabilities (including optimizer, scheduler, and RNG states) along with deferred FSDP gradient reduction using no_sync(). The review feedback highlights several critical issues to address: resolving a potential distributed deadlock in the evaluator when some ranks have empty shards, ensuring scalar metrics are properly reduced across ranks, fixing hardware portability issues with RNG state saving/restoring, handling potential filesystem errors during symlink creation and checkpoint rotation, and cleaning up temporary directories in tests.

Comment thread specforge/eval/evaluator.py Outdated
Comment on lines +81 to +85
totals = torch.tensor(
    [loss_x_tokens, float(total_tokens)], dtype=torch.float64, device=device
)
dist.all_reduce(totals, op=dist.ReduceOp.SUM)
loss_x_tokens, total_tokens = float(totals[0]), float(totals[1])

high

For scalar strategies (e.g., DFlash/Domino), the scalar accuracy metrics (scalar_acc_sum and scalar_acc_n) are not reduced across data-parallel ranks. This causes each rank to compute a local average accuracy instead of a global one, leading to inconsistent evaluation metrics across ranks and incorrect logging on rank 0.

We should include scalar_acc_sum and scalar_acc_n in the totals tensor to reduce them globally.

Suggested change
- totals = torch.tensor(
-     [loss_x_tokens, float(total_tokens)], dtype=torch.float64, device=device
- )
- dist.all_reduce(totals, op=dist.ReduceOp.SUM)
- loss_x_tokens, total_tokens = float(totals[0]), float(totals[1])
+ totals = torch.tensor(
+     [loss_x_tokens, float(total_tokens), scalar_acc_sum, float(scalar_acc_n)],
+     dtype=torch.float64,
+     device=device
+ )
+ dist.all_reduce(totals, op=dist.ReduceOp.SUM)
+ loss_x_tokens, total_tokens = float(totals[0]), float(totals[1])
+ scalar_acc_sum, scalar_acc_n = float(totals[2]), int(totals[3])

Comment thread specforge/eval/evaluator.py Outdated
Comment on lines +86 to +88
if per_pos_correct is not None:
    dist.all_reduce(per_pos_correct, op=dist.ReduceOp.SUM)
    dist.all_reduce(per_pos_denom, op=dist.ReduceOp.SUM)

high

If some ranks have an empty evaluation dataset shard (or do not process any batches with per-position accuracy metrics), per_pos_correct will be None on those ranks. Since dist.all_reduce is a collective operation, calling it only on ranks where per_pos_correct is not None will cause a distributed deadlock (hang).

To prevent this, we should check if any rank has per_pos_correct, broadcast the tensor size from a rank that has it to initialize it on ranks where it is None, and then perform the all_reduce collectively on all ranks.

            has_per_pos = torch.tensor([1 if per_pos_correct is not None else 0], dtype=torch.int32, device=device)
            dist.all_reduce(has_per_pos, op=dist.ReduceOp.SUM)
            if has_per_pos.item() > 0:
                rank_has = torch.tensor([dist.get_rank() if per_pos_correct is not None else 999999], dtype=torch.int32, device=device)
                dist.all_reduce(rank_has, op=dist.ReduceOp.MIN)
                src_rank = int(rank_has.item())
                
                size_tensor = torch.tensor([per_pos_correct.numel() if per_pos_correct is not None else 0], dtype=torch.int32, device=device)
                dist.broadcast(size_tensor, src=src_rank)
                numel = int(size_tensor.item())
                
                if per_pos_correct is None:
                    per_pos_correct = torch.zeros(numel, dtype=torch.float32, device=device)
                    per_pos_denom = torch.zeros(numel, dtype=torch.float32, device=device)
                
                dist.all_reduce(per_pos_correct, op=dist.ReduceOp.SUM)
                dist.all_reduce(per_pos_denom, op=dist.ReduceOp.SUM)
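The PR description states a different schedule from the snippet above: a MAX over the per-position length, then one padded SUM that every rank joins. The sketch below simulates that idea in plain Python. It is not torch.distributed code, and all_reduce and reduce_per_position are hypothetical helpers.

```python
from typing import Callable, List, Optional


def all_reduce(values: list, op: Callable) -> list:
    """Simulated all-reduce: every rank receives the same reduced value."""
    reduced = op(values)
    return [reduced for _ in values]


def reduce_per_position(shards: List[Optional[List[float]]]) -> List[List[float]]:
    """shards[r] is rank r's per-position counts, or None for an empty shard."""
    # Collective 1 (all ranks): MAX of local lengths; empty shards contribute 0.
    lengths = all_reduce([len(s) if s is not None else 0 for s in shards], max)
    global_len = lengths[0]  # identical on every rank after the MAX
    if global_len == 0:
        # The decision is global, so no rank enters the second collective alone.
        return [[] for _ in shards]
    # Collective 2 (all ranks): zero-pad to the global length, then SUM.
    padded = [
        list(s or []) + [0.0] * (global_len - len(s or []))
        for s in shards
    ]
    return all_reduce(padded, lambda rows: [sum(col) for col in zip(*rows)])


# Rank 1 has an empty shard and rank 2 has a shorter vector than rank 0.
result = reduce_per_position([[3.0, 1.0], None, [2.0]])
nothing = reduce_per_position([None, None])
```

Each rank runs the same sequence of collectives whatever its local data looks like, which is the property that prevents the hang described in this thread.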

Comment thread specforge/runtime/training/backend.py Outdated
Comment on lines +291 to +303
@staticmethod
def _rng_state() -> dict:
    rng = {"cpu": torch.get_rng_state()}
    if torch.cuda.is_available():
        rng["cuda"] = torch.cuda.get_rng_state_all()
    return rng

@staticmethod
def _set_rng_state(rng: dict) -> None:
    if rng.get("cpu") is not None:
        torch.set_rng_state(rng["cpu"])
    if rng.get("cuda") is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(rng["cuda"])

high

Using torch.cuda.get_rng_state_all() and torch.cuda.set_rng_state_all() saves and restores the RNG states of all visible CUDA devices on the node. This creates a strict dependency on the hardware configuration: if a checkpoint is saved on an 8-GPU node and resumed on a 4-GPU or 1-GPU node (or vice versa), torch.cuda.set_rng_state_all will raise a RuntimeError due to a mismatch in the number of devices.

Since each distributed training process only utilizes its current active GPU (set via torch.cuda.set_device), we should only save and restore the RNG state of the current active device using torch.cuda.get_rng_state() and torch.cuda.set_rng_state(). This ensures hardware portability of checkpoints.

Suggested change
- @staticmethod
- def _rng_state() -> dict:
-     rng = {"cpu": torch.get_rng_state()}
-     if torch.cuda.is_available():
-         rng["cuda"] = torch.cuda.get_rng_state_all()
-     return rng
- @staticmethod
- def _set_rng_state(rng: dict) -> None:
-     if rng.get("cpu") is not None:
-         torch.set_rng_state(rng["cpu"])
-     if rng.get("cuda") is not None and torch.cuda.is_available():
-         torch.cuda.set_rng_state_all(rng["cuda"])
+ @staticmethod
+ def _rng_state() -> dict:
+     rng = {"cpu": torch.get_rng_state()}
+     if torch.cuda.is_available():
+         rng["cuda"] = torch.cuda.get_rng_state()
+     return rng
+ @staticmethod
+ def _set_rng_state(rng: dict) -> None:
+     if rng.get("cpu") is not None:
+         torch.set_rng_state(rng["cpu"])
+     if rng.get("cuda") is not None and torch.cuda.is_available():
+         torch.cuda.set_rng_state(rng["cuda"])

Comment thread specforge/eval/evaluator.py Outdated
Comment on lines +76 to +80
device = (
    per_pos_correct.device
    if per_pos_correct is not None
    else torch.device("cuda", torch.cuda.current_device())
)

medium

If a distributed process group is initialized on CPU (e.g., using the Gloo backend for testing or CPU-only environments), calling torch.cuda.current_device() will raise an error.

We should check if CUDA is available before attempting to get the current CUDA device, falling back to CPU if it is not.

Suggested change
- device = (
-     per_pos_correct.device
-     if per_pos_correct is not None
-     else torch.device("cuda", torch.cuda.current_device())
- )
+ device = (
+     per_pos_correct.device
+     if per_pos_correct is not None
+     else torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu")
+ )

Comment thread specforge/training/checkpoint.py Outdated
Comment on lines +116 to +123
def _point(self, name: str, ckpt_dir: str) -> None:
    link = os.path.join(self.output_dir, name)
    if os.path.islink(link) or os.path.exists(link):
        try:
            os.remove(link)
        except OSError:
            return  # a non-symlink collision: leave it, the dir is source of truth
    os.symlink(os.path.abspath(ckpt_dir), link)

medium

Creating a symlink via os.symlink can fail due to filesystem limitations, OS-specific restrictions (e.g., lack of symlink privileges on Windows), or permission issues. Since symlinks are convenience pointers (latest and best), failing to create them should not crash an active training run.

We should wrap os.symlink in a try-except OSError block to ensure that any failure to create the symlink is gracefully handled (e.g., logged as a warning) rather than crashing the training process.

Suggested change
- def _point(self, name: str, ckpt_dir: str) -> None:
-     link = os.path.join(self.output_dir, name)
-     if os.path.islink(link) or os.path.exists(link):
-         try:
-             os.remove(link)
-         except OSError:
-             return  # a non-symlink collision: leave it, the dir is source of truth
-     os.symlink(os.path.abspath(ckpt_dir), link)
+ def _point(self, name: str, ckpt_dir: str) -> None:
+     link = os.path.join(self.output_dir, name)
+     if os.path.islink(link) or os.path.exists(link):
+         try:
+             os.remove(link)
+         except OSError:
+             return  # a non-symlink collision: leave it, the dir is source of truth
+     try:
+         os.symlink(os.path.abspath(ckpt_dir), link)
+     except OSError:
+         pass  # gracefully ignore symlink creation failures to avoid crashing the run

Comment thread specforge/training/checkpoint.py Outdated
Comment on lines +125 to +134
def _rotate(self) -> None:
    if self.max_checkpoints <= 0:
        return
    dirs = sorted(self._all_checkpoints(), key=lambda kv: kv[0])
    for step, path in dirs[: -self.max_checkpoints]:
        if step == self.best_step:
            continue  # never rotate away the tracked best
        for f in glob.glob(os.path.join(path, "*")):
            os.remove(f)
        os.rmdir(path)

medium

If deleting an old checkpoint file or directory fails (e.g., due to file locks, transient filesystem issues, or permission errors), the raised OSError will crash the active training run. Since checkpoint rotation is a non-critical cleanup operation, failures should be handled gracefully.

We should wrap the file and directory removal calls in a try-except OSError block to prevent cleanup failures from crashing the training.

Suggested change
- def _rotate(self) -> None:
-     if self.max_checkpoints <= 0:
-         return
-     dirs = sorted(self._all_checkpoints(), key=lambda kv: kv[0])
-     for step, path in dirs[: -self.max_checkpoints]:
-         if step == self.best_step:
-             continue  # never rotate away the tracked best
-         for f in glob.glob(os.path.join(path, "*")):
-             os.remove(f)
-         os.rmdir(path)
+ def _rotate(self) -> None:
+     if self.max_checkpoints <= 0:
+         return
+     dirs = sorted(self._all_checkpoints(), key=lambda kv: kv[0])
+     for step, path in dirs[: -self.max_checkpoints]:
+         if step == self.best_step:
+             continue  # never rotate away the tracked best
+         try:
+             for f in glob.glob(os.path.join(path, "*")):
+                 os.remove(f)
+             os.rmdir(path)
+         except OSError:
+             pass  # ignore cleanup failures to avoid crashing the run
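The PR description says rotation protects both best_step and the just-written step, and warns on rmtree failures rather than swallowing them silently. A standalone sketch of that policy follows. The rotate function and its dict-of-paths argument are hypothetical and do not mirror CheckpointManager's internals.

```python
import logging
import os
import shutil
import tempfile

logger = logging.getLogger("rotation_sketch")


def rotate(checkpoints, max_checkpoints, best_step=None, just_written=None):
    """checkpoints maps step -> directory path. Returns the steps removed."""
    if max_checkpoints <= 0:
        return []  # rotation disabled
    removed = []
    # Everything except the newest max_checkpoints steps is a candidate.
    for step in sorted(checkpoints)[:-max_checkpoints]:
        if step in (best_step, just_written):
            continue  # protected: never rotate these away
        try:
            shutil.rmtree(checkpoints[step])
        except OSError as exc:
            # Cleanup is best-effort: warn and keep training.
            logger.warning("could not remove %s: %s", checkpoints[step], exc)
            continue
        removed.append(step)
    return removed


with tempfile.TemporaryDirectory() as out:
    ckpts = {}
    for step in (10, 20, 30, 40):
        ckpts[step] = os.path.join(out, f"run-step{step}")
        os.makedirs(ckpts[step])
    removed = rotate(ckpts, max_checkpoints=2, best_step=10, just_written=40)
    remaining = sorted(s for s, p in ckpts.items() if os.path.isdir(p))
```

Here step 10 survives because it is the tracked best, so only step 20 is removed even though both fall outside the newest two.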

from specforge.runtime.training.trainer import TrainerController, TrainerCore

TTT, BS, TOTAL, CUT = 3, 2, 6, 3
workdir = tempfile.mkdtemp(prefix="ckpt_continuity_")

medium

The temporary directory created with tempfile.mkdtemp is never cleaned up, which leads to disk space accumulation (resource leaks) on the test runner or developer machine.

We should use self.addCleanup with shutil.rmtree to ensure the directory is cleaned up after the test finishes.

Suggested change
- workdir = tempfile.mkdtemp(prefix="ckpt_continuity_")
+ workdir = tempfile.mkdtemp(prefix="ckpt_continuity_")
+ import shutil
+ self.addCleanup(shutil.rmtree, workdir)

def test_rotation_pointers_and_best(self):
    from specforge.training.checkpoint import CheckpointManager

    out = tempfile.mkdtemp(prefix="ckpt_mgr_")

medium

The temporary directory created with tempfile.mkdtemp is never cleaned up, leading to disk space accumulation.

We should use self.addCleanup with shutil.rmtree to ensure the directory is cleaned up after the test finishes.

Suggested change
- out = tempfile.mkdtemp(prefix="ckpt_mgr_")
+ out = tempfile.mkdtemp(prefix="ckpt_mgr_")
+ import shutil
+ self.addCleanup(shutil.rmtree, out)

def test_no_sync_matches_per_step_reduction(self):
    import torch.multiprocessing as mp

    workdir = tempfile.mkdtemp(prefix="no_sync_equiv_")

medium

The temporary directory created with tempfile.mkdtemp is never cleaned up, leading to disk space accumulation.

We should use self.addCleanup with shutil.rmtree to ensure the directory is cleaned up after the test finishes.

Suggested change
- workdir = tempfile.mkdtemp(prefix="no_sync_equiv_")
+ workdir = tempfile.mkdtemp(prefix="no_sync_equiv_")
+ import shutil
+ self.addCleanup(shutil.rmtree, workdir)

maocheng23 added a commit that referenced this pull request Jul 1, 2026
…fe Evaluator, durable best tracking

Post-review fixes for #637 (adversarial review vs the #630 roadmap):

Multi-rank resume correctness:
- CheckpointManager.save now writes each rank's optimizer/RNG to its own
  training_state_rank{r}.pt beside the rank0 shared payload (draft
  weights + counters + world_size); under FSDP use_orig_params the AdamW
  moments live on rank-local shard views, so persisting only rank0's
  copy and restoring it everywhere corrupted the other ranks' moments
  and collapsed their RNG streams. read_resume_state hands each rank
  back its own state and fails fast on a world-size mismatch (and on
  legacy single-file checkpoints at world_size>1).

Resume repositions the data stream (plan.md G1 seek()-equivalent):
- TrainerController tracks epoch_batch, persists it, and skips the
  consumed prefix of the interrupted epoch on resume (via the new
  FeatureDataLoader.seek — no feature materialization — or islice for
  plain iterables). The domain Trainer threads it for refs-mode runs;
  the online queue path keeps control-plane skip_ids reconciliation.
- test_resume_loss_curve_continuity now trains on DISTINCT batches, so
  resuming on the wrong data cannot pass (the old fixed-batch form was
  structurally blind to the data position).

Evaluator DP correctness:
- The collective schedule is decided globally (SUM of scalar sums, MAX
  of the per-position length, then ONE stacked count reduce) so a rank
  with an empty or scalar-only shard issues the same collectives as its
  peers — no NCCL desync. Scalar accuracy (DFlash/Domino) is now
  reduced across ranks like everything else. Collectives use the local
  device via specforge.utils.get_local_device (CPU for gloo).
  Accumulation stays on-device (one host sync after the loop), and
  eval/per_position_acc is reported alongside the folded acc-len.
- New 2-process gloo gate: cross-rank scalar reduction + ragged-shard
  schedule symmetry (test_evaluator_aggregation).

Durable, decoupled best tracking:
- CheckpointManager rehydrates best_score/best_step from best_meta.json
  (now also carrying "score") on construction, so a restarted process
  neither rotates away the on-disk best nor lets a worse score
  overwrite it; update_best is split into score()/is_better().
- fit() tracks the best on EVERY eval (when checkpointing is enabled),
  persisting a checkpoint on demand when the best eval lands off the
  save cadence — previously best only fired when eval_interval and
  save_interval coincided on the same step.

Surface + cleanup:
- launch: every builder now forwards resume_from/max_checkpoints, so
  resume and rotation are reachable from the build_* entry points.
- TrainerController: drop the max_checkpoints/checkpoint_manager dual
  config (inject a configured manager; the domain Trainer does);
  remove the dead eval_step + mode plumbing (evaluate goes through
  Evaluator on raw forward_loss; validate_batch still runs inside every
  strategy's forward_loss); drop the did_eval/last_eval_metrics
  threading.
- Trainer._load_resume_state removed: CheckpointManager.read_resume_state
  is the single checkpoint reader; the loaded dict is dropped after the
  weight copy instead of living through the FSDP wrap.
- CheckpointManager._rotate uses shutil.rmtree; symlink creation is
  guarded for filesystems without symlink support (dirs + best_meta.json
  stay the source of truth); save_checkpoint filters draft weights on
  rank0 only.
- test_no_sync_equiv now also counts no_sync deferrals (exactly
  accumulation_steps-1 per optimizer step) — the roadmap's "one
  all-reduce per optimizer step" gate, not just weight equality.
- DESIGN.md updated (per-rank layout, seek, best semantics; fixed the
  inverted no_sync sentence).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
maocheng23 added a commit that referenced this pull request Jul 1, 2026
E1's evaluator/best-tracking half landed with Phase D (#637); the
EvalConfig/EvalCache half is the E1 PR, with cache wiring into the run
surface deferred to Phase E's config+CLI (noted explicitly).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@maocheng23

One more commit from the E0/E1 review pass, placed here because it fixes this PR's code: c71560f — the Evaluator's scalar-accuracy path (DFlash/Domino) is now token-weighted (sum of correct over sum of tokens) instead of a mean of per-batch means, making it batch-size invariant like the per-position path (new gate: test_scalar_accuracy_is_batch_size_invariant; the DP gate's expectation updated to the weighted semantics); and CheckpointManager's best/latest symlinks now use relative targets so resume_from=<dir>/best survives relocating the output directory. Revalidated: full tests/test_runtime = 240 OK (2 skip, 1 xfail) on 8×H200.
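The difference between the two scalar-accuracy aggregations can be seen on toy numbers. This is a sketch with hypothetical helper names; each batch is modelled as a (correct, tokens) pair, and returning None for an empty pass is an illustrative choice only.

```python
def mean_of_batch_means(batches):
    """Old behaviour as described: average the per-batch accuracies."""
    return sum(c / n for c, n in batches) / len(batches)


def token_weighted(batches):
    """New behaviour as described: sum of correct over sum of tokens."""
    total_correct = sum(c for c, _ in batches)
    total_tokens = sum(n for _, n in batches)
    # Illustrative choice: report nothing rather than a fabricated zero.
    return total_correct / total_tokens if total_tokens else None


# The same 110 tokens (91 correct), batched two different ways.
two_batches = [(90, 100), (1, 10)]
one_batch = [(91, 110)]

weighted_a = token_weighted(two_batches)
weighted_b = token_weighted(one_batch)
unweighted_a = mean_of_batch_means(two_batches)
unweighted_b = mean_of_batch_means(one_batch)
```

The token-weighted value is the same for both batchings, while the mean of batch means moves from 0.5 to about 0.83 purely because the batch boundaries changed.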

Base automatically changed from dataflow-up-28-colocated-lightweight to dataflow-up-16-zerocopy July 3, 2026 02:38
@jiapingW jiapingW self-requested a review July 3, 2026 02:39
maocheng23 and others added 5 commits July 4, 2026 00:15
…backward

TrainerCore now computes the optimizer-step boundary BEFORE calling backward
and passes it as backward(loss, is_boundary=...). FSDPTrainingBackend runs
non-boundary micro-steps under no_sync() so the gradient reduce-scatter fires
once per optimizer step instead of once per micro-batch — identical math,
one collective per step. GPU equivalence gate: test_no_sync_equiv.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…, per-rank, run-scoped)

FSDPTrainingBackend.state_dict now returns the full {model, optimizer, rng}
training state: model gathered rank0-only with CPU offload, optimizer/RNG kept
rank-local (single bound-device CUDA state). BF16Optimizer persists its fp32
masters so resume does not re-quantize from bf16. CheckpointManager owns the
on-disk layout: atomic run-scoped step dirs, per-rank state files, latest/best
pointers with rotation. Eagle3's filter drops the embedding only when frozen.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
… drift fail-fasts

Trainer(resume_from=...) restores draft weights BEFORE the FSDP wrap, then the
rank-local optimizer/RNG shard after, and fail-fasts on strategy / dataset_size /
accumulation_steps / weight-key drift. TrainerController tracks the live epoch
position (batches + batch-size-independent samples), skips the consumed prefix
via seek()/islice on resume or fit() re-entry, and returns early at max_steps.
save_checkpoint goes through CheckpointManager (shared payload rank0-only,
per-rank optimizer/RNG files). launch.py forwards the resume knobs and
validates online resume; its docstring slimming rides here.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…dule, accuracy_denom weighting

TrainerCore.eval_step is gone: it scalarized away the per-position count
tensors correct acc-len aggregation needs. TrainerController.evaluate now runs
Evaluator.run over raw forward_loss outputs (weighted global counts first,
geometric acc-len sum after DP reduction) at the live StepContext. fit()
broadcasts rank0's eval-enabled verdict so no rank enters the evaluator's
collectives alone, and logs eval/* metrics every eval step. DFlash/Domino
models now surface accuracy_denom so eval weights accuracy exactly.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
fit() now tracks the best eval score: is_better runs only on non-empty,
DP-reduced eval metrics (its guard is rank-identical; the rank0 verdict
broadcast lives inside the manager), a best step gets a checkpoint persisted
on demand even when the eval and save intervals never coincide, and
update_best records the pointer after the save. DESIGN/ARCHITECTURE docs
updated for the Phase D lifecycle.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@maocheng23 maocheng23 marked this pull request as ready for review July 4, 2026 07:24
@maocheng23 maocheng23 requested a review from shuaills as a code owner July 4, 2026 07:24
@maocheng23 maocheng23 force-pushed the dataflow-up-29-training-managers branch from 3ef9f34 to f512828 Compare July 4, 2026 07:24