Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 29 additions & 40 deletions specforge/runtime/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,12 @@
MetadataStore,
SQLiteMetadataStore,
)
from specforge.runtime.data_plane import (
FeatureDataLoader,
FeatureStore,
LocalFeatureStore,
)
from specforge.runtime.training.backend import FSDPTrainingBackend, ParallelConfig
from specforge.runtime.data_plane import FeatureStore, LocalFeatureStore
from specforge.runtime.training.registry import StrategySpec, resolve_strategy
from specforge.runtime.training.trainer import TrainerController, TrainerCore

# The trainer/loader assembly (FeatureDataLoader + FSDPTrainingBackend +
# TrainerCore + TrainerController) now lives in the domain ``Trainer``
# (``specforge.training``); ``_assemble_trainer`` below delegates to it.

# ---------------------------------------------------------------------------
# Shared assemblers — strategy- and topology-agnostic. Every builder is a thin
Expand Down Expand Up @@ -93,48 +91,39 @@ def _assemble_trainer(
the optimizer-step ack — are identical. ``optimizer_factory`` runs AFTER
FSDP-wrap, over the wrapped module's inner draft.
"""
# Offline = a fixed, re-iterable ref set: record committed state so the ack
# lookup works (num_epochs > 1 then re-iterates). Online streams refs through
# a queue and commits them elsewhere (rollout / channel).
if "refs" in ref_source:
controller.enqueue_offline_refs(ref_source["refs"])
trainer_id = controller.register_trainer({"role": "trainer", "run_id": run_id})
loader = FeatureDataLoader(
store,
**ref_source,
batch_size=batch_size,
collate_fn=collate_fn,
per_sample_transform=per_sample_transform,
drop_last=True,
strategy=spec.name,
)

parallel = ParallelConfig.from_distributed(
tp_size=tp_size, sp_ulysses_size=sp_ulysses_size, sp_ring_size=sp_ring_size
)
backend = FSDPTrainingBackend(parallel, optimizer_factory=optimizer_factory)
# FSDP-wrap the composite model and build the optimizer over the inner draft
# AFTER wrapping; the strategy MUST run forward through the wrapped module so
# FSDP is actually in the forward/backward path (not bypassed at >1 rank).
wrapped = backend.prepare_model(model, optimizer_target=model.draft_model)
strategy = spec.make_strategy(wrapped, target_head=target_head)
core = TrainerCore(strategy, backend, accumulation_steps=accumulation_steps)
trainer = TrainerController(
core,
# Delegates to the domain Trainer (``specforge.training``) — the canonical
# assembler for this seam since Phase B3. It performs the exact composition
# this function used to inline; we return the same
# (TrainerController, FeatureDataLoader) tuple so every build_* path is
# unchanged. New code can build a ``Trainer`` directly and call ``.fit()``.
from specforge.training import Trainer

trainer = Trainer(
spec=spec,
controller=controller,
store=store,
ref_source=ref_source,
model=model,
target_head=target_head,
optimizer_factory=optimizer_factory,
run_id=run_id,
output_dir=output_dir,
batch_size=batch_size,
accumulation_steps=accumulation_steps,
num_epochs=num_epochs,
max_steps=max_steps,
total_steps=total_steps,
save_interval=save_interval,
eval_interval=eval_interval,
log_interval=log_interval,
tp_size=tp_size,
sp_ulysses_size=sp_ulysses_size,
sp_ring_size=sp_ring_size,
logger=logger,
ack_fn=lambda ids, step: controller.ack_train_refs(
trainer_id, ids, global_step=step, optimizer_durable=True
),
log_interval=log_interval,
collate_fn=collate_fn,
per_sample_transform=per_sample_transform,
)
return trainer, loader
return trainer.controller, trainer.loader


def _offline_io(spec: StrategySpec, max_len: int):
Expand Down
30 changes: 30 additions & 0 deletions specforge/training/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# coding=utf-8
# Copyright 2024 The SpecForge team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""Domain training layer (Phase B): the caller-facing lifecycle over the runtime.

``Trainer`` WRAPS — does not replace — the runtime training seam
(``TrainerController`` / ``TrainerCore`` / ``DraftTrainStrategy`` /
``FSDPTrainingBackend``). Future managers (CheckpointManager / Evaluator /
lr_scheduler — Phase D) land here on top of the same seam.

Import-light: the ``Trainer`` (which imports the GPU/model-heavy runtime backend)
is imported lazily so ``import specforge.training`` stays cheap.
"""

from __future__ import annotations

__all__ = ["Trainer"]


def __getattr__(name):
    """Lazily re-export ``Trainer`` on first attribute access (PEP 562).

    Only the name ``"Trainer"`` is resolved; it is imported from the sibling
    ``.trainer`` module at lookup time rather than at package import time.
    Any other name raises ``AttributeError`` with the standard
    "module ... has no attribute ..." message.
    """
    if name != "Trainer":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Deferred on purpose: the import only happens once a caller actually
    # asks for ``Trainer``.
    from .trainer import Trainer

    return Trainer
140 changes: 140 additions & 0 deletions specforge/training/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# coding=utf-8
# Copyright 2024 The SpecForge team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""Domain ``Trainer``: the caller-facing training object (Phase B).

``Trainer`` composes — behind one object with a ``.fit()`` — exactly what
``launch._assemble_trainer`` wired inline before:

ref source + FeatureStore
-> FeatureDataLoader(transform, collate) (the data path)
model + spec.make_strategy
-> FSDPTrainingBackend.prepare_model (FSDP wrap)
-> TrainerCore -> TrainerController (the trainer seam)

The runtime seam (``TrainerController`` / ``TrainerCore`` /
``DraftTrainStrategy`` / ``FSDPTrainingBackend``) is byte-for-byte unchanged —
this is the domain facade over it, so ``launch._assemble_trainer`` now delegates
here (one wiring path, no fork). The online / offline / disaggregated distinction
is invisible to ``Trainer``: it is fully absorbed by the (ref source +
``FeatureStore``) it is handed, behind ``FeatureDataLoader -> TrainBatch``. There
is NO ``HiddenStateStream`` — the loader is the stream.
"""

from __future__ import annotations

from typing import Optional

from specforge.runtime.data_plane import FeatureDataLoader, FeatureStore
from specforge.runtime.training.backend import FSDPTrainingBackend, ParallelConfig
from specforge.runtime.training.registry import StrategySpec, resolve_strategy
from specforge.runtime.training.trainer import TrainerController, TrainerCore


class Trainer:
    """Caller-facing training object built on the runtime trainer seam.

    Construction wires, in order: the metadata registration on the dataflow
    controller, a ``FeatureDataLoader`` over the given store and ref source,
    an ``FSDPTrainingBackend`` that wraps the model, the strategy produced by
    ``spec.make_strategy``, a ``TrainerCore`` and finally a runtime
    ``TrainerController``. The assembled pieces are kept as attributes and
    ``fit`` / ``evaluate`` / ``save_checkpoint`` forward to the runtime
    controller.
    """

    def __init__(
        self,
        *,
        spec: StrategySpec,
        controller,  # runtime.control_plane.DataFlowController (metadata only)
        store: FeatureStore,
        ref_source: dict,  # {"refs": [...]} (offline) | {"queue": q} (online)
        model,
        target_head,
        optimizer_factory,
        run_id: str,
        output_dir: str,
        batch_size: int,
        accumulation_steps: int,
        num_epochs: int,
        max_steps: Optional[int],
        save_interval: int,
        eval_interval: int,
        tp_size: int,
        sp_ulysses_size: int,
        sp_ring_size: int,
        logger,
        log_interval: int,
        collate_fn,
        total_steps: Optional[int] = None,
        per_sample_transform=None,
    ):
        """Assemble loader, backend, core and runtime controller.

        All arguments are keyword-only. ``controller`` is the dataflow
        controller used for ref bookkeeping (enqueue / register / ack);
        the runtime ``TrainerController`` built here is stored separately
        as ``self.controller``.
        """
        # --- metadata side -------------------------------------------------
        # An offline run hands over a fixed ref list under "refs"; those are
        # recorded with the dataflow controller up front so the later ack
        # lookup can find them. Without a "refs" key nothing is enqueued here.
        is_offline = "refs" in ref_source
        if is_offline:
            controller.enqueue_offline_refs(ref_source["refs"])
        registration = {"role": "trainer", "run_id": run_id}
        trainer_id = controller.register_trainer(registration)

        # --- data path -----------------------------------------------------
        # ``ref_source`` is splatted as-is, so its key ("refs" or "queue")
        # becomes the loader's ref-source keyword.
        feature_loader = FeatureDataLoader(
            store,
            **ref_source,
            batch_size=batch_size,
            collate_fn=collate_fn,
            per_sample_transform=per_sample_transform,
            drop_last=True,
            strategy=spec.name,
        )

        # --- trainer seam --------------------------------------------------
        parallel_cfg = ParallelConfig.from_distributed(
            tp_size=tp_size,
            sp_ulysses_size=sp_ulysses_size,
            sp_ring_size=sp_ring_size,
        )
        fsdp_backend = FSDPTrainingBackend(
            parallel_cfg, optimizer_factory=optimizer_factory
        )
        # Wrap first, then build the strategy on the WRAPPED module: the
        # optimizer targets the inner draft model, and forward passes must go
        # through the wrapper so FSDP is not bypassed at >1 rank.
        wrapped_model = fsdp_backend.prepare_model(
            model, optimizer_target=model.draft_model
        )
        train_strategy = spec.make_strategy(wrapped_model, target_head=target_head)
        trainer_core = TrainerCore(
            train_strategy, fsdp_backend, accumulation_steps=accumulation_steps
        )

        def _ack_refs(ids, step):
            # Report consumed refs back to the dataflow controller for this
            # trainer registration, flagged as optimizer-durable.
            return controller.ack_train_refs(
                trainer_id, ids, global_step=step, optimizer_durable=True
            )

        runtime_controller = TrainerController(
            trainer_core,
            run_id=run_id,
            output_dir=output_dir,
            num_epochs=num_epochs,
            max_steps=max_steps,
            total_steps=total_steps,
            save_interval=save_interval,
            eval_interval=eval_interval,
            log_interval=log_interval,
            logger=logger,
            ack_fn=_ack_refs,
        )

        # Expose the runtime pieces for callers that still want them directly
        # (launch._assemble_trainer returns the (controller, loader) pair).
        self.spec = spec
        self.dataflow_controller = controller
        self.trainer_id = trainer_id
        self.backend = fsdp_backend
        self.core = trainer_core
        #: the runtime TrainerController (has fit / evaluate / save_checkpoint)
        self.controller = runtime_controller
        #: the FeatureDataLoader -> TrainBatch iterator (the canonical "stream")
        self.loader = feature_loader

    @classmethod
    def from_strategy_name(cls, strategy: str, **kwargs) -> "Trainer":
        """Look up the :class:`StrategySpec` for ``strategy`` and construct.

        ``kwargs`` are forwarded unchanged to ``__init__``.
        """
        resolved_spec = resolve_strategy(strategy)
        return cls(spec=resolved_spec, **kwargs)

    def fit(self, eval_data=None) -> int:
        """Train over ``self.loader``; returns the final global step."""
        runtime = self.controller
        return runtime.fit(self.loader, eval_data=eval_data)

    def evaluate(self, data=None):
        """Evaluate on ``data``, or on ``self.loader`` when ``data`` is None."""
        if data is None:
            data = self.loader
        return self.controller.evaluate(data)

    def save_checkpoint(self, step: Optional[int] = None):
        """Save a checkpoint for ``step``.

        When ``step`` is None the runtime controller's current
        ``global_step`` is used.
        """
        if step is None:
            step = self.controller.global_step
        return self.controller.save_checkpoint(step)


__all__ = ["Trainer"]
Loading
Loading