Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 252 additions & 18 deletions specforge/runtime/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from specforge.runtime.contracts import SampleRef
from specforge.runtime.control_plane import DataFlowController
from specforge.runtime.control_plane.metadata_store import (
InMemoryMetadataStore,
MetadataStore,
SQLiteMetadataStore,
)
Expand Down Expand Up @@ -455,11 +456,16 @@ def build_disagg_online_producer(
(commit/ack state stays private to this process). See
:func:`_resolve_metadata_store`.

Returns ``(workers, drive_producer)``. ``drive_producer()`` runs the workers
until the prompt pool drains, publishing refs to the channel and applying
backpressure (it pauses while ``channel.in_flight_remote()`` exceeds
Returns ``(workers, drive_producer)``. ``drive_producer(should_stop=...)`` runs
the workers until the prompt pool drains, publishing refs to the channel and
applying backpressure (it pauses while ``channel.in_flight_remote()`` exceeds
``in_flight_high_watermark`` so a lagging trainer can't overrun the Mooncake
segment), then closes the channel so the consumer's loader terminates.
``should_stop`` (a zero-arg predicate) lets a caller wind the producer down
early — e.g. the interleaved driver sets it once the trainer hits ``max_steps``
so the producer doesn't block forever on the watermark after the consumer
stops draining (O1.2). The channel is always closed on exit so the consumer
never hangs on a finished producer.
"""
import time

Expand Down Expand Up @@ -500,24 +506,30 @@ def build_disagg_online_producer(
for i in range(num_rollout_workers)
]

def drive_producer(max_rounds: int = 1_000_000, should_stop=None) -> int:
    """Run the rollout workers until the prompt pool drains, then close the channel.

    Each round asks every worker for up to ``lease`` tasks (``run_once``) and
    publishes the collected refs with a single ``channel.publish_many`` call.
    Before a round starts, the loop waits while ``channel.in_flight_remote()``
    is at or above ``in_flight_high_watermark``, polling every
    ``backpressure_poll_s`` seconds.

    Args:
        max_rounds: Upper bound on the number of produce rounds.
        should_stop: Optional zero-arg predicate. It is checked at the top of
            every round and on every backpressure poll; once it returns true
            the producer exits instead of generating (or waiting) further.

    Returns:
        The number of refs published.

    The channel is closed on every exit path -- normal drain, early stop, an
    exception inside the loop, and a failure while starting the workers.
    """

    def _stop_requested() -> bool:
        return should_stop is not None and should_stop()

    produced = 0
    try:
        # Worker start-up lives inside the try so a failing ``start()`` still
        # closes the channel; otherwise the consumer would wait forever on a
        # producer that never published anything.
        for w in workers:
            w.start()
        for _ in range(max_rounds):
            if _stop_requested():
                break  # caller asked us to wind down (e.g. trainer finished)
            # backpressure: don't let the producer outrun the consumer (and
            # the Mooncake segment). in_flight = published - consumer-acked.
            while channel.in_flight_remote() >= in_flight_high_watermark:
                if _stop_requested():
                    return produced  # don't block on the watermark forever
                sleep(backpressure_poll_s)
            refs = []
            for w in workers:
                refs.extend(w.run_once(max_tasks=lease))
            if not refs:
                break  # prompt pool drained
            channel.publish_many(refs)
            produced += len(refs)
        return produced
    finally:
        channel.close()  # EOF -> the consumer's loader terminates once drained

return workers, drive_producer

Expand Down Expand Up @@ -628,6 +640,226 @@ def build_disagg_online_consumer(
return trainer, loader


def run_disagg_online_interleaved(
    *,
    trainer,
    loader,
    drive_producer,
    channel,
    producer_max_rounds: int = 1_000_000,
    join_timeout_s: Optional[float] = 30.0,
) -> int:
    """Run an online producer and the trainer CONCURRENTLY (O1.2).

    Replaces the synchronous drain-then-fit shape (generate the whole prompt
    pool, *then* train) with a live loop: ``drive_producer`` runs on a
    background daemon thread while ``trainer.fit(loader)`` runs on the calling
    thread.

    Shutdown paths:

    * trainer finishes first (e.g. ``max_steps``) -> the stop event is set, and
      the producer sees it through its ``should_stop`` predicate;
    * producer finishes first -> ``drive_producer`` is expected to close the
      channel so ``fit`` can return once the loader drains;
    * producer raises -> the exception is recorded, the channel is closed here
      as a safety net, and the exception is re-raised once ``fit`` has unwound.

    Args:
        trainer: Object exposing ``fit(loader)`` that returns the final step.
        loader: Passed through to ``trainer.fit`` unchanged.
        drive_producer: Callable invoked as
            ``drive_producer(producer_max_rounds, should_stop=<predicate>)``.
        channel: Producer-side channel; only ``close()`` is called on it, and
            only when the producer raised.
        producer_max_rounds: Forwarded as the producer's round limit.
        join_timeout_s: How long to wait for the producer thread after the
            trainer returns; ``None`` waits indefinitely.

    Returns:
        The value returned by ``trainer.fit`` (the final optimizer step).

    Raises:
        RuntimeError: The producer thread was still alive after
            ``join_timeout_s`` (chained from the trainer error, if any).
        BaseException: The producer's exception (chained from the trainer's
            when both failed), otherwise the trainer's exception.
    """
    import threading

    stop = threading.Event()
    err: dict = {}

    def _produce() -> None:
        try:
            drive_producer(producer_max_rounds, should_stop=stop.is_set)
        except BaseException as exc:  # surfaced to the main thread below
            err["exc"] = exc
            channel.close()  # never leave the consumer blocked on a dead producer

    thread = threading.Thread(
        target=_produce, name="disagg-online-producer", daemon=True
    )
    thread.start()
    trainer_exc: Optional[BaseException] = None
    try:
        step = trainer.fit(loader)
    except BaseException as exc:  # noqa: BLE001 - re-raised below, chained w/ producer
        trainer_exc = exc
    finally:
        stop.set()  # trainer done (or failed) -> tell the producer to wind down
        thread.join(timeout=join_timeout_s)

    if thread.is_alive():
        # The producer overran join_timeout_s and is still running: a daemon thread
        # that would keep publishing into a store no consumer drains. Fail loudly
        # instead of returning "success" with a leaked, still-live producer.
        msg = (
            f"disagg online producer did not wind down within {join_timeout_s}s of "
            "trainer exit (still alive); abandoning it would leak an active rollout"
        )
        if trainer_exc is not None:
            raise RuntimeError(msg) from trainer_exc
        raise RuntimeError(msg)

    # Read the producer's outcome only AFTER the thread is known to have
    # finished. Reading it before the liveness check leaves a window in which a
    # producer that outlived the join timeout can fail and exit unnoticed, and
    # the run would then report success with the failure dropped.
    producer_exc = err.get("exc")

    # A producer failure closes the channel, which is usually what makes trainer.fit
    # fail downstream, so surface the producer exception as the root cause and chain
    # the trainer error so neither is silently lost.
    if producer_exc is not None:
        if trainer_exc is not None:
            raise producer_exc from trainer_exc
        raise producer_exc
    if trainer_exc is not None:
        raise trainer_exc
    return step


def build_disagg_online_eagle3_runtime(
    *,
    target_model,
    prompts,
    eagle3_model,
    optimizer_factory,
    feature_store: FeatureStore,
    run_id: str,
    output_dir: str,
    target_hidden_size: int,
    channel=None,
    ref_channel_path: Optional[str] = None,
    target_vocab_size: Optional[int] = None,
    draft_vocab_size: Optional[int] = None,
    target_repr: str = "logits",
    aux_hidden_state_layer_ids=None,
    vocab_map_version: Optional[str] = None,
    t2d=None,
    num_rollout_workers: int = 1,
    device: str = "cuda",
    lease: int = 8,
    in_flight_high_watermark: int = 256,
    batch_size: int = 1,
    accumulation_steps: int = 1,
    num_epochs: int = 1,
    max_steps: Optional[int] = None,
    save_interval: int = 0,
    eval_interval: int = 0,
    tp_size: int = 1,
    sp_ulysses_size: int = 1,
    sp_ring_size: int = 1,
    collate_fn=None,
    idle_timeout_s: Optional[float] = None,
    metadata_store: Optional[MetadataStore] = None,
    metadata_db_path: Optional[str] = None,
    resume: bool = False,
    join_timeout_s: Optional[float] = 30.0,
    logger=None,
    log_interval: int = 50,
):
    """Build the one-process online disaggregated EAGLE3 runtime (O1.2).

    Wires :func:`build_disagg_online_producer` and
    :func:`build_disagg_online_consumer` together so both halves share the same
    ``feature_store``, the same ``run_id``, the same metadata-store arguments,
    and two :class:`StreamingRefChannel` views over one path (the producer
    writes through one view, the consumer reads through the other).

    Channel selection: pass ``channel`` (its ``path`` attribute is reused for
    the consumer view) or ``ref_channel_path`` (a producer channel is built
    over it). ``channel`` wins when both are given; ``ValueError`` is raised
    when neither is.

    Metadata store: when neither ``metadata_store`` nor ``metadata_db_path`` is
    supplied, a single ``InMemoryMetadataStore`` is created and handed to both
    halves. Otherwise the caller's ``metadata_store`` / ``metadata_db_path``
    are forwarded unchanged to both builders.

    Returns:
        ``(trainer, loader, run)``. Calling ``run()`` executes
        :func:`run_disagg_online_interleaved` with the built trainer, loader,
        producer driver and producer channel, using ``join_timeout_s``, and
        returns its result.

    Raises:
        ValueError: neither ``channel`` nor ``ref_channel_path`` was provided.
    """
    from specforge.runtime.data_plane.streaming_ref_channel import StreamingRefChannel

    # Resolve the producer-side channel and the path both views share.
    if channel is None and ref_channel_path is None:
        raise ValueError("provide either `channel` or `ref_channel_path`")
    if channel is None:
        path = ref_channel_path
        producer_channel = StreamingRefChannel(path)
    else:
        producer_channel = channel
        path = channel.path
    # The consumer always gets its own view over the same path, so each side
    # keeps an independent read/write offset.
    consumer_channel = StreamingRefChannel(path)

    # With no store and no db path requested, both halves must still observe the
    # same commits and acks, so create one in-memory instance to share.
    if metadata_store is None and metadata_db_path is None:
        shared_store = InMemoryMetadataStore()
    else:
        shared_store = metadata_store

    # Arguments that must be identical on the producer and consumer side.
    store_kwargs = {
        "metadata_store": shared_store,
        "metadata_db_path": metadata_db_path,
    }

    _workers, drive_producer = build_disagg_online_producer(
        target_model=target_model,
        prompts=prompts,
        feature_store=feature_store,
        channel=producer_channel,
        run_id=run_id,
        target_hidden_size=target_hidden_size,
        target_vocab_size=target_vocab_size,
        draft_vocab_size=draft_vocab_size,
        target_repr=target_repr,
        aux_hidden_state_layer_ids=aux_hidden_state_layer_ids,
        vocab_map_version=vocab_map_version,
        t2d=t2d,
        num_rollout_workers=num_rollout_workers,
        device=device,
        lease=lease,
        in_flight_high_watermark=in_flight_high_watermark,
        **store_kwargs,
    )

    trainer, loader = build_disagg_online_consumer(
        feature_store=feature_store,
        channel=consumer_channel,
        eagle3_model=eagle3_model,
        optimizer_factory=optimizer_factory,
        run_id=run_id,
        output_dir=output_dir,
        batch_size=batch_size,
        accumulation_steps=accumulation_steps,
        num_epochs=num_epochs,
        max_steps=max_steps,
        save_interval=save_interval,
        eval_interval=eval_interval,
        tp_size=tp_size,
        sp_ulysses_size=sp_ulysses_size,
        sp_ring_size=sp_ring_size,
        collate_fn=collate_fn,
        idle_timeout_s=idle_timeout_s,
        resume=resume,
        logger=logger,
        log_interval=log_interval,
        **store_kwargs,
    )

    def run() -> int:
        # Drive producer and trainer concurrently; the producer-side channel is
        # the one the runner closes if the producer fails.
        final_step = run_disagg_online_interleaved(
            trainer=trainer,
            loader=loader,
            drive_producer=drive_producer,
            channel=producer_channel,
            join_timeout_s=join_timeout_s,
        )
        return final_step

    return trainer, loader, run


# Backward-compatible alias: keeps the older ``build_offline_eagle3_controller``
# name bound to ``build_offline_eagle3_runtime`` for early branch users.
build_offline_eagle3_controller = build_offline_eagle3_runtime

Expand All @@ -639,4 +871,6 @@ def build_disagg_online_consumer(
"build_online_eagle3_runtime",
"build_disagg_online_producer",
"build_disagg_online_consumer",
"build_disagg_online_eagle3_runtime",
"run_disagg_online_interleaved",
]
Loading
Loading