Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
291 changes: 250 additions & 41 deletions specforge/runtime/data_plane/mooncake_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,54 @@ def _connect_store(setup_kwargs: Dict[str, Any]) -> Any:
return store


# FeatureSpec.dtype (a string) -> torch dtype, for allocating zero-copy receive
# tensors from the ref alone (the ref carries shape+dtype, so get() needs no
# serialized header).
_TORCH_DTYPES = {
    "float32": torch.float32,
    "float64": torch.float64,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "int64": torch.int64,
    "int32": torch.int32,
    "int16": torch.int16,
    "int8": torch.int8,
    "uint8": torch.uint8,
    "bool": torch.bool,
}


def _alloc_from_spec(spec) -> torch.Tensor:
    """Allocate a fresh tensor matching a FeatureSpec (the zero-copy dst).

    Args:
        spec: any object exposing ``dtype`` and ``shape``. ``dtype`` may be a
            bare name listed in ``_TORCH_DTYPES`` ("float32"), the
            ``str(torch.dtype)`` spelling ("torch.float32"), or a
            ``torch.dtype`` instance. ``shape`` is an iterable of values
            convertible with ``int()``.

    Returns:
        An uninitialized tensor (``torch.empty``) of the requested shape/dtype.

    Raises:
        KeyError: the dtype is not one of the supported names.
    """
    raw = spec.dtype
    if isinstance(raw, torch.dtype):
        # Already a torch dtype: nothing to look up.
        dtype = raw
    else:
        # Tolerate the "torch.float32" spelling in case the producer recorded
        # str(tensor.dtype) rather than the bare name.
        name = str(raw)
        prefix = "torch."
        if name.startswith(prefix):
            name = name[len(prefix):]
        dtype = _TORCH_DTYPES.get(name)
    if dtype is None:
        raise KeyError(f"unsupported feature dtype {spec.dtype!r} for zero-copy get")
    return torch.empty(tuple(int(d) for d in spec.shape), dtype=dtype)


def _nbytes(t: torch.Tensor) -> int:
    """Return the byte size of ``t``'s elements: element count x element size."""
    per_element = t.element_size()
    count = t.numel()
    return count * per_element


class MooncakeFeatureStore(FeatureStore):
"""A disaggregated :class:`FeatureStore` backed by the Mooncake store.

One Mooncake object per sample (``{store_id}/{sample_id}``) holds a
``torch.save``'d ``{"generation": int, "tensors": dict}`` blob, hard-pinned so
Mooncake's LRU never reclaims a live feature. ``store`` may be injected (any
object exposing ``is_exist/put/get/remove`` with the Mooncake signatures) so
the contract is unit-testable without a running master.
**Zero-copy transport (default).** One hard-pinned Mooncake object per
*tensor*, keyed ``{store_id}/{sample_id}/g{gen}/{name}``. ``put()`` writes each
tensor straight from its storage with ``put_from(ptr)``; ``get()`` reads each
straight into a tensor allocated from the ref's ``FeatureSpec`` with
``get_into(ptr)``. There is no ``torch.save``/``torch.load`` pickle round-trip
on the hot path — shape/dtype travel on the ref, the bytes are the raw tensor
buffer. The generation lives in the key (like ``SharedDirFeatureStore``'s
filename generation), so a re-put supersedes the old key set and a stale ref's
keys are gone -> ``get()`` raises (B5).

Set ``zero_copy=False`` (or inject a backend without ``put_from``/``get_into``)
to fall back to the single-object ``torch.save`` blob path.

``store`` may be injected (any object exposing the Mooncake method subset:
``is_exist/remove`` plus either ``put_from``/``get_into`` or ``put``/``get``)
so the contract is unit-testable without a running master.
"""

def __init__(
Expand All @@ -147,6 +187,7 @@ def __init__(
max_release_attempts: int = 3,
replica_num: int = 1,
hard_pin: bool = True,
zero_copy: bool = True,
clock: Callable[[], float] = time.monotonic,
) -> None:
self.auth = auth or AuthPolicy()
Expand All @@ -159,6 +200,16 @@ def __init__(
store = _connect_store(kw)
self._store = store
self._put_config = _build_replicate_config(replica_num, hard_pin)
# Zero-copy transport: one Mooncake object per *tensor*, written straight
# from the tensor's storage via put_from(ptr) and read straight into a
# spec-allocated tensor via get_into(ptr) -- no torch.save/torch.load
# pickle round-trip. Falls back to the pickle path if the backend lacks
# the raw-buffer API (older mooncake / a fake without it).
self._zero_copy = (
bool(zero_copy)
and callable(getattr(store, "put_from", None))
and callable(getattr(store, "get_into", None))
)
self.max_resident_bytes = max_resident_bytes
self.max_hold_age_s = max_hold_age_s
# Offline re-iterable mode: release() must NOT free (multi-epoch); mirrors
Expand All @@ -170,6 +221,11 @@ def __init__(
self._generation: Dict[str, int] = {}
self._put_time: Dict[str, float] = {}
self._sample_bytes: Dict[str, int] = {}
# feature names per resident sample -> the per-tensor keys to remove on
# free (zero-copy mode). Cached on both put() (producer) and get()
# (consumer) so each side can free the sample it owns/consumed without the
# ref in hand at release() time.
self._sample_names: Dict[str, List[str]] = {}
self._active_leases: Dict[str, FeatureHandle] = {}
# samples whose remote remove() failed; retried/force-freed by gc()
self._release_pending: Dict[str, int] = {}
Expand All @@ -190,6 +246,12 @@ def __init__(
def _key(self, sample_id: str) -> str:
    """Build the single-object key for a sample: ``{store_id}/{sample_id}``."""
    store_id = self.store_id
    return f"{store_id}/{sample_id}"

def _tkey(self, sample_id: str, gen: int, name: str) -> str:
    """Build the per-tensor key ``{store_id}/{sample_id}/g{gen}/{name}``.

    The generation is part of the key itself, so each generation of a sample
    addresses a distinct key set and no generation needs to travel in the
    payload.
    """
    sample_prefix = f"{self.store_id}/{sample_id}"
    return f"{sample_prefix}/g{gen}/{name}"

# -- store wrappers (status-code aware) --------------------------------
def _store_exists(self, key: str) -> bool:
    """Report whether ``key`` is present; only a backend status of 1 counts."""
    status = int(self._store.is_exist(key))
    return status == 1
Expand All @@ -199,6 +261,62 @@ def _store_put(self, key: str, value: bytes) -> None:
if rc is not None and int(rc) != 0:
raise RuntimeError(f"mooncake put failed (status {rc}) for {key}")

def _store_put_tensor(self, key: str, t: torch.Tensor) -> None:
    """Zero-copy publish: hand the backend ``t``'s raw buffer via ``put_from``.

    No torch.save: the bytes are the raw tensor buffer; shape/dtype travel on
    the ref's FeatureSpec, so get() needs no header. The source is registered
    with the transfer engine for the duration of the put -- RDMA transfers it
    by DMA and rejects an unregistered address (AddressNotRegistered); TCP
    ignores the registration.

    Args:
        key: destination object key.
        t: source tensor. Must be contiguous and on the CPU, because the
            backend reads ``numel * element_size`` bytes starting at
            ``t.data_ptr()``.

    Raises:
        ValueError: ``t`` is not a contiguous CPU tensor. Checked up front so a
            strided or device tensor is never published as wrong bytes.
        RuntimeError: ``put_from`` returned a negative status.
    """
    if t.device.type != "cpu" or not t.is_contiguous():
        raise ValueError(
            f"mooncake put_from requires a contiguous CPU tensor for {key} "
            f"(device={t.device}, contiguous={t.is_contiguous()})"
        )
    ptr = t.data_ptr()
    nb = _nbytes(t)
    try:
        self._store.register_buffer(ptr, nb)
    except Exception:  # pragma: no cover - some builds auto-register
        # Deliberately best-effort: registration is not needed on every build.
        pass
    try:
        rc = self._store.put_from(key, ptr, nb, self._put_config)
    finally:
        # Always undo the registration, even if put_from raised.
        try:
            self._store.unregister_buffer(ptr)
        except Exception:  # pragma: no cover
            pass
    # Only a negative status is treated as failure; None counts as success.
    if rc is not None and int(rc) < 0:
        raise RuntimeError(f"mooncake put_from failed (status {rc}) for {key}")

def _store_get_tensor(self, key: str, out: torch.Tensor) -> None:
    """Zero-copy fetch of ``key`` into the pre-allocated tensor ``out``.

    The receive buffer is registered with the transfer engine for the get_into
    (required by the raw-buffer path), then unregistered.

    Args:
        key: object key to read.
        out: destination tensor. Must be contiguous and on the CPU, because
            the backend writes ``numel * element_size`` bytes starting at
            ``out.data_ptr()``.

    Raises:
        ValueError: ``out`` is not a contiguous CPU tensor. Checked up front so
            the backend never writes raw bytes into a strided or device buffer.
        KeyError: ``get_into`` returned ``None`` or a negative status, or the
            read was short.
    """
    if out.device.type != "cpu" or not out.is_contiguous():
        raise ValueError(
            f"mooncake get_into requires a contiguous CPU tensor for {key} "
            f"(device={out.device}, contiguous={out.is_contiguous()})"
        )
    ptr = out.data_ptr()
    nb = _nbytes(out)
    try:
        self._store.register_buffer(ptr, nb)
    except Exception:  # pragma: no cover - some builds auto-register
        # Deliberately best-effort: registration is not needed on every build.
        pass
    try:
        rc = self._store.get_into(key, ptr, nb)
    finally:
        # Always undo the registration, even if get_into raised.
        try:
            self._store.unregister_buffer(ptr)
        except Exception:  # pragma: no cover
            pass
    if rc is None or int(rc) < 0:
        raise KeyError(f"mooncake get_into failed (status {rc}) for {key}")
    # get_into returns the number of bytes read; a full read returns exactly
    # nb. A short read (0 <= rc < nb) would leave the tail of the caller's
    # buffer unfilled, and the raw-buffer path has no other way to detect
    # under-fill. Reject it rather than hand back silently-corrupt data
    # (B5: never serve wrong bytes).
    if int(rc) != nb:
        raise KeyError(
            f"mooncake get_into short read for {key}: got {rc} of {nb} bytes"
        )

def _store_remove(self, key: str) -> bool:
"""Best-effort physical free. Returns True on confirmed removal."""
try:
Expand All @@ -218,9 +336,9 @@ def put(
self.auth.check(self._credential)
if not tensors:
raise ValueError("put requires at least one tensor")
staged = {k: v.detach().cpu() for k, v in tensors.items()}
staged = {k: v.detach().cpu().contiguous() for k, v in tensors.items()}
specs = {k: spec_from_tensor(k, v) for k, v in staged.items()}
nbytes = sum(t.numel() * t.element_size() for t in staged.values())
nbytes = sum(_nbytes(t) for t in staged.values())
with self._lock:
if (
self.max_resident_bytes is not None
Expand All @@ -232,23 +350,50 @@ def put(
)
self._gen_counter += 1
gen = self._gen_counter
buf = io.BytesIO()
torch.save({"generation": gen, "tensors": staged}, buf)
key = self._key(sample_id)
# Overwrite-safe publish: a re-put bumps the generation. remove() first so
# the hard-pinned prior blob is released rather than orphaned; if that
# remove fails the old (pinned) blob may leak, so surface it loudly.
if self._store_exists(key) and not self._store_remove(key):
logger.warning(
"MooncakeFeatureStore re-put of %s: removing the stale blob failed; "
"a hard-pinned object may be orphaned",
key,
)
self._store_put(key, buf.getvalue())
prior_gen = self._generation.get(sample_id)
prior_names = self._sample_names.get(sample_id, [])
if self._zero_copy:
# One hard-pinned object per tensor, DMA'd straight from its storage.
# staged keeps the source tensors alive across the synchronous puts.
for name, t in staged.items():
self._store_put_tensor(self._tkey(sample_id, gen, name), t)
# Overwrite-safe: drop the prior generation's tensor keys so a stale
# ref's keys are gone (its get() then raises -> no use-after-free).
if prior_gen is not None and prior_gen != gen:
leaked = [
name
for name in prior_names
if not self._store_remove(self._tkey(sample_id, prior_gen, name))
]
if leaked:
logger.warning(
"MooncakeFeatureStore re-put of %s gen %s: removing prior "
"generation %s tensors %s failed; hard-pinned objects may be "
"orphaned (and the stale ref stays readable until reclaimed)",
sample_id,
prior_gen,
prior_gen,
leaked,
)
else:
buf = io.BytesIO()
torch.save({"generation": gen, "tensors": staged}, buf)
key = self._key(sample_id)
# Overwrite-safe publish: a re-put bumps the generation. remove() first
# so the hard-pinned prior blob is released rather than orphaned; if
# that remove fails the old (pinned) blob may leak, so surface it.
if self._store_exists(key) and not self._store_remove(key):
logger.warning(
"MooncakeFeatureStore re-put of %s: removing the stale blob "
"failed; a hard-pinned object may be orphaned",
key,
)
self._store_put(key, buf.getvalue())
with self._lock:
self._generation[sample_id] = gen
self._put_time[sample_id] = self._clock()
self._sample_bytes[sample_id] = nbytes
self._sample_names[sample_id] = list(staged)
return SampleRef(
sample_id=sample_id,
run_id=str(metadata.get("run_id", "unknown")),
Expand Down Expand Up @@ -288,9 +433,60 @@ def get(
f"sample {sid} generation {ref_gen} was released/aborted; "
f"refusing use-after-free"
)
wanted = names or list(sample_ref.feature_keys.keys())
if self._zero_copy:
out, gen = self._get_zero_copy(sample_ref, wanted)
else:
out, gen = self._get_pickle(sample_ref, wanted)
if str(device) != "cpu":
out = {k: v.to(device) for k, v in out.items()}
with self._lock:
self._counter += 1
# Consumer-side cache: a process that only get()s a sample (never
# put() it) still needs gen + feature names so its release()/abort()
# can free the per-tensor keys. setdefault keeps the producer's own
# entries authoritative when producer and consumer are one instance.
self._generation.setdefault(sid, gen)
self._sample_names.setdefault(sid, list(sample_ref.feature_keys.keys()))
handle = FeatureHandle(
sample_id=sid,
generation=gen,
lease_token=f"{sid}:{self._counter}",
)
self._active_leases[handle.lease_token] = handle
return out, handle

def _get_zero_copy(
    self, ref: SampleRef, wanted: List[str]
) -> Tuple[Dict[str, torch.Tensor], int]:
    """Fetch each wanted feature into a freshly allocated tensor (no pickle).

    The generation is read from ``ref.metadata["generation"]`` and each
    feature's shape/dtype from ``ref.feature_specs``; the bytes are read from
    the per-tensor key for that generation.

    Returns:
        ``(tensors_by_name, generation)``.

    Raises:
        KeyError: the ref carries no generation, has no spec for a wanted
            feature, or a per-tensor key is not present in the store.
    """
    sid = ref.sample_id
    ref_gen = ref.metadata.get("generation")
    if ref_gen is None:
        raise KeyError(f"sample {sid} ref carries no generation; cannot locate")
    gen = int(ref_gen)
    fetched: Dict[str, torch.Tensor] = {}
    for n in wanted:
        spec = ref.feature_specs.get(n)
        if spec is None:
            raise KeyError(f"sample {sid} ref has no spec for feature {n!r}")
        key = self._tkey(sid, gen, n)
        if not self._store_exists(key):
            # freed (release/abort), superseded by a re-put, or never written
            raise KeyError(
                f"sample {sid} gen {gen} feature {n!r} not available "
                f"(freed, stale, or never written)"
            )
        # A fresh destination per fetch keeps the returned tensor independent
        # of the transport (B5).
        dst = _alloc_from_spec(spec)
        fetched[n] = dst
        self._store_get_tensor(key, dst)
    return fetched, gen

def _get_pickle(
self, ref: SampleRef, wanted: List[str]
) -> Tuple[Dict[str, torch.Tensor], int]:
sid = ref.sample_id
key = self._key(sid)
if not self._store_exists(key):
# freed by release/abort, or never written -> never hand back stale
raise KeyError(f"sample {sid} not available in store {self.store_id}")
value = self._store.get(key)
if not value:
Expand All @@ -301,47 +497,60 @@ def get(
payload = torch.load(io.BytesIO(value), weights_only=True)
on_disk_gen = payload.get("generation")
on_disk_gen = int(on_disk_gen) if on_disk_gen is not None else None
ref_gen = sample_ref.metadata.get("generation", on_disk_gen)
ref_gen = ref.metadata.get("generation", on_disk_gen)
if on_disk_gen is not None and ref_gen != on_disk_gen:
# re-put after this ref was minted -> stale handle
raise KeyError(
f"sample {sid} generation {ref_gen} is stale "
f"(current {on_disk_gen}); refusing use-after-free"
)
raw = payload["tensors"]
wanted = names or list(sample_ref.feature_keys.keys())
out: Dict[str, torch.Tensor] = {}
for n in wanted:
raw_key = sample_ref.feature_keys.get(n, n)
raw_key = ref.feature_keys.get(n, n)
raw_key = raw_key.split("/")[-1] if "/" in raw_key else raw_key
if raw_key not in raw:
raise KeyError(
f"sample {sid} missing key {raw_key!r} for feature {n!r}"
)
# clone-on-fetch (B5): returned tensor is independent of the transport
out[n] = raw[raw_key].clone()
if str(device) != "cpu":
out = {k: v.to(device) for k, v in out.items()}
with self._lock:
self._counter += 1
handle = FeatureHandle(
sample_id=sid,
generation=on_disk_gen or 0,
lease_token=f"{sid}:{self._counter}",
)
self._active_leases[handle.lease_token] = handle
return out, handle
out[n] = raw[raw_key].clone() # clone-on-fetch (B5)
return out, (on_disk_gen or 0)

# -- lifetime ----------------------------------------------------------
def _try_physical_free(self, sample_id: str) -> bool:
"""Remove the remote object. False on a (retryable) RPC failure."""
return self._store_remove(self._key(sample_id))
"""Remove the remote object(s). False on a (retryable) RPC failure.

Zero-copy: one object per tensor, so remove every per-tensor key of the
sample's current generation. Pickle: a single object.
"""
if not self._zero_copy:
return self._store_remove(self._key(sample_id))
gen = self._generation.get(sample_id)
if gen is None:
return True # nothing tracked to remove (already freed)
ok = True
for name in self._sample_names.get(sample_id, []):
if not self._store_remove(self._tkey(sample_id, gen, name)):
ok = False
return ok

def _sample_exists(self, sample_id: str) -> bool:
    """Return True if any object backing the sample's tracked generation exists.

    Pickle mode checks the single per-sample key. Zero-copy mode checks the
    per-tensor keys of the tracked generation and stops at the first hit; an
    untracked sample (no generation recorded) reports False.
    """
    if not self._zero_copy:
        return self._store_exists(self._key(sample_id))
    gen = self._generation.get(sample_id)
    if gen is None:
        return False
    for feature in self._sample_names.get(sample_id, []):
        if self._store_exists(self._tkey(sample_id, gen, feature)):
            return True
    return False

def _free_bookkeeping_locked(self, sample_id: str) -> int:
    """Forget all in-process tracking for a sample; return its accounted bytes.

    Returns 0 when no byte count was recorded for the sample.
    NOTE(review): the ``_locked`` suffix suggests the caller holds
    ``self._lock`` -- confirm at the call sites.
    """
    freed = self._sample_bytes.pop(sample_id, 0)
    for table in (
        self._generation,
        self._put_time,
        self._sample_names,
        self._release_pending,
    ):
        table.pop(sample_id, None)
    return freed

Expand Down Expand Up @@ -401,7 +610,7 @@ def gc(self, *, now: Optional[float] = None) -> Dict[str, int]:
self._release_pending.setdefault(sid, 0)
# reconcile release-pending: retry the fallible remote free
for sid in list(self._release_pending):
if not self._store_exists(self._key(sid)):
if not self._sample_exists(sid):
freed_bytes += self._free_bookkeeping_locked(sid)
continue
attempts = self._release_pending[sid] + 1
Expand Down
Loading
Loading