[DataFlow runtime] Phase D — training managers (no_sync, full resume, checkpoint/eval)#637
[DataFlow runtime] Phase D — training managers (no_sync, full resume, checkpoint/eval)#637maocheng23 wants to merge 6 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a batch-size invariant Evaluator for speculative decoding metrics, a CheckpointManager for handling checkpoint layout and rotation, and full training resume capabilities (including optimizer, scheduler, and RNG states) along with deferred FSDP gradient reduction using no_sync(). The review feedback highlights several critical issues to address: resolving a potential distributed deadlock in the evaluator when some ranks have empty shards, ensuring scalar metrics are properly reduced across ranks, fixing hardware portability issues with RNG state saving/restoring, handling potential filesystem errors during symlink creation and checkpoint rotation, and cleaning up temporary directories in tests.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| totals = torch.tensor( | ||
| [loss_x_tokens, float(total_tokens)], dtype=torch.float64, device=device | ||
| ) | ||
| dist.all_reduce(totals, op=dist.ReduceOp.SUM) | ||
| loss_x_tokens, total_tokens = float(totals[0]), float(totals[1]) |
There was a problem hiding this comment.
For scalar strategies (e.g., DFlash/Domino), the scalar accuracy metrics (scalar_acc_sum and scalar_acc_n) are not reduced across data-parallel ranks. This causes each rank to compute a local average accuracy instead of a global one, leading to inconsistent evaluation metrics across ranks and incorrect logging on rank 0.
We should include scalar_acc_sum and scalar_acc_n in the totals tensor to reduce them globally.
| totals = torch.tensor( | |
| [loss_x_tokens, float(total_tokens)], dtype=torch.float64, device=device | |
| ) | |
| dist.all_reduce(totals, op=dist.ReduceOp.SUM) | |
| loss_x_tokens, total_tokens = float(totals[0]), float(totals[1]) | |
| totals = torch.tensor( | |
| [loss_x_tokens, float(total_tokens), scalar_acc_sum, float(scalar_acc_n)], | |
| dtype=torch.float64, | |
| device=device | |
| ) | |
| dist.all_reduce(totals, op=dist.ReduceOp.SUM) | |
| loss_x_tokens, total_tokens = float(totals[0]), float(totals[1]) | |
| scalar_acc_sum, scalar_acc_n = float(totals[2]), int(totals[3]) |
| if per_pos_correct is not None: | ||
| dist.all_reduce(per_pos_correct, op=dist.ReduceOp.SUM) | ||
| dist.all_reduce(per_pos_denom, op=dist.ReduceOp.SUM) |
There was a problem hiding this comment.
If some ranks have an empty evaluation dataset shard (or do not process any batches with per-position accuracy metrics), per_pos_correct will be None on those ranks. Since dist.all_reduce is a collective operation, calling it only on ranks where per_pos_correct is not None will cause a distributed deadlock (hang).
To prevent this, we should check if any rank has per_pos_correct, broadcast the tensor size from a rank that has it to initialize it on ranks where it is None, and then perform the all_reduce collectively on all ranks.
has_per_pos = torch.tensor([1 if per_pos_correct is not None else 0], dtype=torch.int32, device=device)
dist.all_reduce(has_per_pos, op=dist.ReduceOp.SUM)
if has_per_pos.item() > 0:
rank_has = torch.tensor([dist.get_rank() if per_pos_correct is not None else 999999], dtype=torch.int32, device=device)
dist.all_reduce(rank_has, op=dist.ReduceOp.MIN)
src_rank = int(rank_has.item())
size_tensor = torch.tensor([per_pos_correct.numel() if per_pos_correct is not None else 0], dtype=torch.int32, device=device)
dist.broadcast(size_tensor, src=src_rank)
numel = int(size_tensor.item())
if per_pos_correct is None:
per_pos_correct = torch.zeros(numel, dtype=torch.float32, device=device)
per_pos_denom = torch.zeros(numel, dtype=torch.float32, device=device)
dist.all_reduce(per_pos_correct, op=dist.ReduceOp.SUM)
dist.all_reduce(per_pos_denom, op=dist.ReduceOp.SUM)| @staticmethod | ||
| def _rng_state() -> dict: | ||
| rng = {"cpu": torch.get_rng_state()} | ||
| if torch.cuda.is_available(): | ||
| rng["cuda"] = torch.cuda.get_rng_state_all() | ||
| return rng | ||
|
|
||
| @staticmethod | ||
| def _set_rng_state(rng: dict) -> None: | ||
| if rng.get("cpu") is not None: | ||
| torch.set_rng_state(rng["cpu"]) | ||
| if rng.get("cuda") is not None and torch.cuda.is_available(): | ||
| torch.cuda.set_rng_state_all(rng["cuda"]) |
There was a problem hiding this comment.
Using torch.cuda.get_rng_state_all() and torch.cuda.set_rng_state_all() saves and restores the RNG states of all visible CUDA devices on the node. This creates a strict dependency on the hardware configuration: if a checkpoint is saved on an 8-GPU node and resumed on a 4-GPU or 1-GPU node (or vice versa), torch.cuda.set_rng_state_all will raise a RuntimeError due to a mismatch in the number of devices.
Since each distributed training process only utilizes its current active GPU (set via torch.cuda.set_device), we should only save and restore the RNG state of the current active device using torch.cuda.get_rng_state() and torch.cuda.set_rng_state(). This ensures hardware portability of checkpoints.
| @staticmethod | |
| def _rng_state() -> dict: | |
| rng = {"cpu": torch.get_rng_state()} | |
| if torch.cuda.is_available(): | |
| rng["cuda"] = torch.cuda.get_rng_state_all() | |
| return rng | |
| @staticmethod | |
| def _set_rng_state(rng: dict) -> None: | |
| if rng.get("cpu") is not None: | |
| torch.set_rng_state(rng["cpu"]) | |
| if rng.get("cuda") is not None and torch.cuda.is_available(): | |
| torch.cuda.set_rng_state_all(rng["cuda"]) | |
| @staticmethod | |
| def _rng_state() -> dict: | |
| rng = {"cpu": torch.get_rng_state()} | |
| if torch.cuda.is_available(): | |
| rng["cuda"] = torch.cuda.get_rng_state() | |
| return rng | |
| @staticmethod | |
| def _set_rng_state(rng: dict) -> None: | |
| if rng.get("cpu") is not None: | |
| torch.set_rng_state(rng["cpu"]) | |
| if rng.get("cuda") is not None and torch.cuda.is_available(): | |
| torch.cuda.set_rng_state(rng["cuda"]) |
| device = ( | ||
| per_pos_correct.device | ||
| if per_pos_correct is not None | ||
| else torch.device("cuda", torch.cuda.current_device()) | ||
| ) |
There was a problem hiding this comment.
If a distributed process group is initialized on CPU (e.g., using the Gloo backend for testing or CPU-only environments), calling torch.cuda.current_device() will raise an error.
We should check if CUDA is available before attempting to get the current CUDA device, falling back to CPU if it is not.
| device = ( | |
| per_pos_correct.device | |
| if per_pos_correct is not None | |
| else torch.device("cuda", torch.cuda.current_device()) | |
| ) | |
| device = ( | |
| per_pos_correct.device | |
| if per_pos_correct is not None | |
| else torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu") | |
| ) |
| def _point(self, name: str, ckpt_dir: str) -> None: | ||
| link = os.path.join(self.output_dir, name) | ||
| if os.path.islink(link) or os.path.exists(link): | ||
| try: | ||
| os.remove(link) | ||
| except OSError: | ||
| return # a non-symlink collision: leave it, the dir is source of truth | ||
| os.symlink(os.path.abspath(ckpt_dir), link) |
There was a problem hiding this comment.
Creating a symlink via os.symlink can fail due to filesystem limitations, OS-specific restrictions (e.g., lack of symlink privileges on Windows), or permission issues. Since symlinks are convenience pointers (latest and best), failing to create them should not crash an active training run.
We should wrap os.symlink in a try-except OSError block to ensure that any failure to create the symlink is gracefully handled (e.g., logged as a warning) rather than crashing the training process.
| def _point(self, name: str, ckpt_dir: str) -> None: | |
| link = os.path.join(self.output_dir, name) | |
| if os.path.islink(link) or os.path.exists(link): | |
| try: | |
| os.remove(link) | |
| except OSError: | |
| return # a non-symlink collision: leave it, the dir is source of truth | |
| os.symlink(os.path.abspath(ckpt_dir), link) | |
| def _point(self, name: str, ckpt_dir: str) -> None: | |
| link = os.path.join(self.output_dir, name) | |
| if os.path.islink(link) or os.path.exists(link): | |
| try: | |
| os.remove(link) | |
| except OSError: | |
| return # a non-symlink collision: leave it, the dir is source of truth | |
| try: | |
| os.symlink(os.path.abspath(ckpt_dir), link) | |
| except OSError: | |
| pass # gracefully ignore symlink creation failures to avoid crashing the run |
| def _rotate(self) -> None: | ||
| if self.max_checkpoints <= 0: | ||
| return | ||
| dirs = sorted(self._all_checkpoints(), key=lambda kv: kv[0]) | ||
| for step, path in dirs[: -self.max_checkpoints]: | ||
| if step == self.best_step: | ||
| continue # never rotate away the tracked best | ||
| for f in glob.glob(os.path.join(path, "*")): | ||
| os.remove(f) | ||
| os.rmdir(path) |
There was a problem hiding this comment.
If deleting an old checkpoint file or directory fails (e.g., due to file locks, transient filesystem issues, or permission errors), the raised OSError will crash the active training run. Since checkpoint rotation is a non-critical cleanup operation, failures should be handled gracefully.
We should wrap the file and directory removal calls in a try-except OSError block to prevent cleanup failures from crashing the training.
| def _rotate(self) -> None: | |
| if self.max_checkpoints <= 0: | |
| return | |
| dirs = sorted(self._all_checkpoints(), key=lambda kv: kv[0]) | |
| for step, path in dirs[: -self.max_checkpoints]: | |
| if step == self.best_step: | |
| continue # never rotate away the tracked best | |
| for f in glob.glob(os.path.join(path, "*")): | |
| os.remove(f) | |
| os.rmdir(path) | |
| def _rotate(self) -> None: | |
| if self.max_checkpoints <= 0: | |
| return | |
| dirs = sorted(self._all_checkpoints(), key=lambda kv: kv[0]) | |
| for step, path in dirs[: -self.max_checkpoints]: | |
| if step == self.best_step: | |
| continue # never rotate away the tracked best | |
| try: | |
| for f in glob.glob(os.path.join(path, "*")): | |
| os.remove(f) | |
| os.rmdir(path) | |
| except OSError: | |
| pass # ignore cleanup failures to avoid crashing the run |
| from specforge.runtime.training.trainer import TrainerController, TrainerCore | ||
|
|
||
| TTT, BS, TOTAL, CUT = 3, 2, 6, 3 | ||
| workdir = tempfile.mkdtemp(prefix="ckpt_continuity_") |
There was a problem hiding this comment.
The temporary directory created with tempfile.mkdtemp is never cleaned up, which leads to disk space accumulation (resource leaks) on the test runner or developer machine.
We should use self.addCleanup with shutil.rmtree to ensure the directory is cleaned up after the test finishes.
| workdir = tempfile.mkdtemp(prefix="ckpt_continuity_") | |
| workdir = tempfile.mkdtemp(prefix="ckpt_continuity_") | |
| import shutil | |
| self.addCleanup(shutil.rmtree, workdir) |
| def test_rotation_pointers_and_best(self): | ||
| from specforge.training.checkpoint import CheckpointManager | ||
|
|
||
| out = tempfile.mkdtemp(prefix="ckpt_mgr_") |
There was a problem hiding this comment.
The temporary directory created with tempfile.mkdtemp is never cleaned up, leading to disk space accumulation.
We should use self.addCleanup with shutil.rmtree to ensure the directory is cleaned up after the test finishes.
| out = tempfile.mkdtemp(prefix="ckpt_mgr_") | |
| out = tempfile.mkdtemp(prefix="ckpt_mgr_") | |
| import shutil | |
| self.addCleanup(shutil.rmtree, out) |
| def test_no_sync_matches_per_step_reduction(self): | ||
| import torch.multiprocessing as mp | ||
|
|
||
| workdir = tempfile.mkdtemp(prefix="no_sync_equiv_") |
There was a problem hiding this comment.
The temporary directory created with tempfile.mkdtemp is never cleaned up, leading to disk space accumulation.
We should use self.addCleanup with shutil.rmtree to ensure the directory is cleaned up after the test finishes.
| workdir = tempfile.mkdtemp(prefix="no_sync_equiv_") | |
| workdir = tempfile.mkdtemp(prefix="no_sync_equiv_") | |
| import shutil | |
| self.addCleanup(shutil.rmtree, workdir) |
…fe Evaluator, durable best tracking Post-review fixes for #637 (adversarial review vs the #630 roadmap): Multi-rank resume correctness: - CheckpointManager.save now writes each rank's optimizer/RNG to its own training_state_rank{r}.pt beside the rank0 shared payload (draft weights + counters + world_size); under FSDP use_orig_params the AdamW moments live on rank-local shard views, so persisting only rank0's copy and restoring it everywhere corrupted the other ranks' moments and collapsed their RNG streams. read_resume_state hands each rank back its own state and fails fast on a world-size mismatch (and on legacy single-file checkpoints at world_size>1). Resume repositions the data stream (plan.md G1 seek()-equivalent): - TrainerController tracks epoch_batch, persists it, and skips the consumed prefix of the interrupted epoch on resume (via the new FeatureDataLoader.seek — no feature materialization — or islice for plain iterables). The domain Trainer threads it for refs-mode runs; the online queue path keeps control-plane skip_ids reconciliation. - test_resume_loss_curve_continuity now trains on DISTINCT batches, so resuming on the wrong data cannot pass (the old fixed-batch form was structurally blind to the data position). Evaluator DP correctness: - The collective schedule is decided globally (SUM of scalar sums, MAX of the per-position length, then ONE stacked count reduce) so a rank with an empty or scalar-only shard issues the same collectives as its peers — no NCCL desync. Scalar accuracy (DFlash/Domino) is now reduced across ranks like everything else. Collectives use the local device via specforge.utils.get_local_device (CPU for gloo). Accumulation stays on-device (one host sync after the loop), and eval/per_position_acc is reported alongside the folded acc-len. - New 2-process gloo gate: cross-rank scalar reduction + ragged-shard schedule symmetry (test_evaluator_aggregation). Durable, decoupled best tracking: - CheckpointManager rehydrates best_score/best_step from best_meta.json (now also carrying "score") on construction, so a restarted process neither rotates away the on-disk best nor lets a worse score overwrite it; update_best is split into score()/is_better(). - fit() tracks the best on EVERY eval (when checkpointing is enabled), persisting a checkpoint on demand when the best eval lands off the save cadence — previously best only fired when eval_interval and save_interval coincided on the same step. Surface + cleanup: - launch: every builder now forwards resume_from/max_checkpoints, so resume and rotation are reachable from the build_* entry points. - TrainerController: drop the max_checkpoints/checkpoint_manager dual config (inject a configured manager; the domain Trainer does); remove the dead eval_step + mode plumbing (evaluate goes through Evaluator on raw forward_loss; validate_batch still runs inside every strategy's forward_loss); drop the did_eval/last_eval_metrics threading. - Trainer._load_resume_state removed: CheckpointManager.read_resume_state is the single checkpoint reader; the loaded dict is dropped after the weight copy instead of living through the FSDP wrap. - CheckpointManager._rotate uses shutil.rmtree; symlink creation is guarded for filesystems without symlink support (dirs + best_meta.json stay the source of truth); save_checkpoint filters draft weights on rank0 only. - test_no_sync_equiv now also counts no_sync deferrals (exactly accumulation_steps-1 per optimizer step) — the roadmap's "one all-reduce per optimizer step" gate, not just weight equality. - DESIGN.md updated (per-rank layout, seek, best semantics; fixed the inverted no_sync sentence). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
E1's evaluator/best-tracking half landed with Phase D (#637); the EvalConfig/EvalCache half is the E1 PR, with cache wiring into the run surface deferred to Phase E's config+CLI (noted explicitly). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
|
One more commit from the E0/E1 review pass, placed here because it fixes this PR's code: |
…backward TrainerCore now computes the optimizer-step boundary BEFORE calling backward and passes it as backward(loss, is_boundary=...). FSDPTrainingBackend runs non-boundary micro-steps under no_sync() so the gradient reduce-scatter fires once per optimizer step instead of once per micro-batch — identical math, one collective per step. GPU equivalence gate: test_no_sync_equiv. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…, per-rank, run-scoped)
FSDPTrainingBackend.state_dict now returns the full {model, optimizer, rng}
training state: model gathered rank0-only with CPU offload, optimizer/RNG kept
rank-local (single bound-device CUDA state). BF16Optimizer persists its fp32
masters so resume does not re-quantize from bf16. CheckpointManager owns the
on-disk layout: atomic run-scoped step dirs, per-rank state files, latest/best
pointers with rotation. Eagle3's filter drops the embedding only when frozen.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
… drift fail-fasts Trainer(resume_from=...) restores draft weights BEFORE the FSDP wrap, then the rank-local optimizer/RNG shard after, and fail-fasts on strategy / dataset_size / accumulation_steps / weight-key drift. TrainerController tracks the live epoch position (batches + batch-size-independent samples), skips the consumed prefix via seek()/islice on resume or fit() re-entry, and returns early at max_steps. save_checkpoint goes through CheckpointManager (shared payload rank0-only, per-rank optimizer/RNG files). launch.py forwards the resume knobs and validates online resume; its docstring slimming rides here. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…dule, accuracy_denom weighting TrainerCore.eval_step is gone: it scalarized away the per-position count tensors correct acc-len aggregation needs. TrainerController.evaluate now runs Evaluator.run over raw forward_loss outputs (weighted global counts first, geometric acc-len sum after DP reduction) at the live StepContext. fit() broadcasts rank0's eval-enabled verdict so no rank enters the evaluator's collectives alone, and logs eval/* metrics every eval step. DFlash/Domino models now surface accuracy_denom so eval weights accuracy exactly. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
fit() now tracks the best eval score: is_better runs only on non-empty, DP-reduced eval metrics (its guard is rank-identical; the rank0 verdict broadcast lives inside the manager), a best step gets a checkpoint persisted on demand even when the eval and save intervals never coincide, and update_best records the pointer after the save. DESIGN/ARCHITECTURE docs updated for the Phase D lifecycle. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
3ef9f34 to
f512828
Compare
Phase D — training managers (no_sync, full resume, checkpoint/eval)
Implements Phase D of the reconciled plan (#630,
docs/roadmap/domain-refactor.md§D): brings the training loop to production parity. Stacked on #636 (Phase C). Four features plus best-checkpoint tracking, each listed below with its full API surface and how it is implemented.1. Gradient accumulation with
no_sync()Where:
specforge/runtime/training/trainer.py(TrainerCore),specforge/runtime/training/backend.py(FSDPTrainingBackend).TrainerCore.train_step(batch, ctx)accumulation_steps, decides the optimizer boundary before backward (self._micro % accumulation_steps == 0) and passes it down, steps the optimizer only at the boundary.TrainingBackend.backward(loss, *, is_boundary=True)(ABC)FSDPTrainingBackend.backwardself.module.no_sync()— the FSDP gradient reduce-scatter fires once per optimizer step, not once per micro-step.accumulation_steps=1/ unwrapped modules are byte-identical to before.Deliberately removed:
TrainerCore.eval_stepand themodeplumbing — evaluation now consumes rawforward_lossoutputs (see §4) because_resultscalarizes away the per-position count tensors that correct acceptance-length aggregation needs.2. Full resume
Where:
specforge/runtime/training/backend.py,specforge/optimizer.py,specforge/training/trainer.py(domain Trainer),specforge/runtime/data_plane/feature_dataloader.py.FSDPTrainingBackend.state_dict(){model, optimizer, rng}.modelis gathered underFullStateDictConfig(offload_to_cpu=True, rank0_only=True)(no per-rank GPU replica; non-rank0 gets{}).rngholds the CPU state plus the bound CUDA device's state only — checkpoints are independent ofCUDA_VISIBLE_DEVICESlayout.FSDPTrainingBackend.load_state_dict(state)BF16Optimizer.state_dict/load_state_dictfp32_params(older checkpoint) falls back to re-clone with one warning.Trainer(resume_from=...)(domain)read_resume_state→ draft weights before the FSDP wrap (so the fp32 master is rebuilt from them) → optimizer/scheduler/RNG afterprepare_model→start_step/epoch/data position. Weight load is guarded: non-empty key overlap, zero unexpected keys, and saved-strategy-name match, elseValueError(a mismatched checkpoint can no longer silently "resume" from random init). Persists and validatesdataset_sizeandaccumulation_steps— drift fails fast.FeatureDataLoader.seek(num_batches)TrainerController.fit()global_step >= max_stepsbefore consuming any data (idempotent restart); the mid-epoch skip is driven by the live position counter so re-enteringfit()never re-trains the prefix; the plain-iterable skip path validates the consumed count.3. CheckpointManager
Where:
specforge/training/checkpoint.py(new module).CheckpointManager(output_dir, run_id, *, max_checkpoints=0, best_metric='eval/simulated_acc_len', best_min_delta=0.0)checkpoint_dir(step){output_dir}/{run_id}-step{N}/containingtraining_state.pt(rank0 shared payload) +training_state_rank{r}.ptper rank — required because FSDPuse_orig_paramsshards AdamW moments; rank0-only optimizer state would corrupt every other rank on resume.save(state, step, *, rank_state=None)makedirs; atomic writes (tmp +os.replace) for every payload; timeline rewind: rank0 deletes any on-disk step ≥ the one being saved (fork/rollback semantics, logged) and clears a best record inside the deleted range; cross-rank failure propagation: anall_gatherof per-rank error strings makes one rank's FS error raise on all ranks instead of stranding the group in a barrier; rotation never deletesbest_stepor the just-written step, and itsrmtreefailures warn instead of killing the run.is_better(eval_metrics)score > best_score + best_min_deltaand the verdict is broadcast inside the manager — callers can't accidentally diverge rank decisions before the collective-bearing save. Empty metrics →False.update_best(step, eval_metrics){run_id}.best_meta.jsonwrite +{run_id}-bestrepoint. Meta recordsrun_idandmetric; the loader ignores mismatched or corrupt meta with one warning. All pointer/meta artifacts are run_id-scoped so runs sharing anoutput_dircannot clobber each other.read_resume_state(path, *, map_location='cpu', require_full_state=True)'backend'(no key-rename seams). Fails fast on world-size mismatch and on optimizer-less checkpoints (require_full_state=Falseopts into weights+counters-only).load(step=None)/latest_dir()latestis trusted only when it is a real symlink pointing at a complete checkpoint; a materialized (copied)latestis repaired on the next save; the fallback scans only complete step dirs, so a truncated dir from a mid-save preemption is invisible. Relative symlinks keep pointers valid after relocatingoutput_dir.4. Evaluator
Where:
specforge/eval/evaluator.py(new module).Evaluator.run(model, batches, strategy, ...)simulated_acc_len, so the metric is batch-size- and rank-invariant (never a mean of ratios). Emitseval/avg_loss,eval/avg_acc,eval/simulated_acc_len, plus aggregated legacyeval/acceptance_rate_i/eval/ploss_iwhen the strategy provides them.{}— never fabricated zeros, so an empty/drained eval set cannot poison best tracking.forward_lossnow emitaccuracy_denom(the accuracy's own denominator); the Evaluator weights scalar accuracy by it, restoring batch-size invariance for real scalar shapes (loss-token fallback documented).TrainerController.evaluate(data)forward_lossoutputs to the Evaluator under the sameStepContextas training (Domino's schedule-dependent loss evaluates consistently);data=Nonejoins the collectives with an empty shard. Eval entry itself is a rank0-broadcast decision, and eval metrics are logged and merged intolast_metrics.5. Best-checkpoint tracking
Any eval that beats the record (by more than
best_min_delta) persists a checkpoint on demand — decoupled fromsave_intervalcadence. The record survives restarts via meta rehydration, is rank-agreed by construction (§3is_better), and rotation always protects it.Wiring
Every
build_*entry point inspecforge/runtime/launch.pyforwardsresume_from/max_checkpoints. Disagg-online builders fail fast whenresume_fromis given withoutresume=True, and when the durable ack marker is ahead of the checkpoint (acked-but-not-checkpointed samples would be silently skipped — the reconcile docstring now states the accurate guarantee). New managers live in their S-homes (training/,eval/); the runtime seam imports them lazily.Tests / gates
test_no_sync_equiv.py— 2-rank FSDP, distinct per-rank data: the accumulation path matches per-micro-step reduction, and deferral counting pins exactlyaccumulation_steps−1deferrals per optimizer step.test_checkpoint_manager.py(new) — atomicity (truncated dirs invisible), timeline rewind, run-scoped meta/pointers, corrupt/mismatched meta, glob-metachar run_ids, materialized-latestrepair,is_better/update_bestsemantics incl.best_min_delta,read_resume_statebranches, and a 2-process gloo save→per-rank-read gate (dynamic port).test_checkpoint_resume.py— bit-for-bit weight round-trip with an asserted missing-key set; resume continuity vs an uninterrupted reference on distinct batches with an explicit LR-scheduler position assert; the productionTrainer(resume_from=...)entrypoint end-to-end; fail-fast gates (strategy mismatch, zero-overlap, dataset-size drift, accumulation-steps drift, world-size mismatch);max_stepsidempotent restart; re-enteredfit()prefix safety.test_evaluator_aggregation.py— aggregate-before-geometric-sum, batch-size invariance under realistic mismatched denominators, zero-batch →{}, float64 counts, 2-proc gloo ragged/empty-shard schedule (scoped note: covers the Evaluator's own collectives).test_feature_dataloader.py/test_trainer.py— seek semantics incl. bounds + non-materialization, boundary-flag sequence on CPU, eval metrics reach the logger, checkpoint payload extras.Validation
tests/test_runtimesuite: 241 passed, 24 skipped (CUDA/multi-GPU gates), 1 xfailed locally after the review-fix pass.Review history
Three adversarial review rounds are folded in:
max_stepsidempotency, embed-filter freeze check, online resume knob coupling — plus a comment-slimming pass across all touched modules and the test hardening listed above.🤖 Generated with Claude Code
How to review — commit by commit
The branch is structured as 5 self-contained behavioral commits, each with its own tests and a green CPU suite at that point. Reviewing commit-by-commit is the intended path:
6e1b739eb5e263{model, optimizer(+fp32 masters), rng}; atomic, per-rank, run-scoped layoute96441cseek-based data repositioning, drift fail-fastsfa882d1accuracy_denomweightingf512828Rebased onto
dataflow-up-16-zerocopyafter the B/C stack (#622–#636) merged; adopts thebuild_control_plane_for_moderename from the Phase C review fixes.