Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions tests/generate/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,6 +1404,69 @@ def test_sglang_jax_1d_kv_bias_alignment(self):
self.assertTrue(jnp.allclose(result.params[src_key], expected))


def test_transfer_state_directly_fuses_moe_weights(self):
  """Source wi_0/wi_1 are fused into the single wi the target expects."""
  gate_a = jnp.array([[1.0, 2.0], [5.0, 6.0]], dtype=jnp.float32)
  gate_b = jnp.array([[3.0, 4.0], [7.0, 8.0]], dtype=jnp.float32)

  source = nnx.Dict(
      layers=nnx.Dict(wi_0=nnx.Param(gate_a), wi_1=nnx.Param(gate_b))
  )
  target = nnx.Dict(
      layers=nnx.Dict(wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32)))
  )

  # Identity reshard: this test only exercises fusion, not sharding.
  utils.transfer_state_directly(
      source, target, reshard_fn=lambda source, target: source
  )

  np.testing.assert_array_equal(
      target['layers']['wi'][...],
      jnp.concatenate([gate_a, gate_b], axis=-1),
  )

def test_transfer_state_directly_fuses_moe_weights_scanned_to_unrolled(self):
  """Scanned wi_0/wi_1 are unstacked and fused into per-layer wi (unrolled dst)."""
  # Layout: [num_layers=2, experts=2, features=2] -> fused [2, 4] per layer.
  stacked_wi_0 = jnp.array(
      [[[1., 2.], [5., 6.]], [[10., 20.], [50., 60.]]], dtype=jnp.float32
  )
  stacked_wi_1 = jnp.array(
      [[[3., 4.], [7., 8.]], [[30., 40.], [70., 80.]]], dtype=jnp.float32
  )

  source = nnx.Dict(
      layers=nnx.Dict(
          wi_0=nnx.Param(stacked_wi_0),
          wi_1=nnx.Param(stacked_wi_1),
      )
  )
  unrolled_layers = {
      f'layers_{idx}': nnx.Dict(
          wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32))
      )
      for idx in range(2)
  }
  target = nnx.Dict(**unrolled_layers)

  utils.transfer_state_directly(
      source, target, reshard_fn=lambda source, target: source, scan_axis=0
  )

  for idx in range(2):
    np.testing.assert_array_equal(
        target[f'layers_{idx}']['wi'][...],
        jnp.concatenate([stacked_wi_0[idx], stacked_wi_1[idx]], axis=-1),
    )


class ResolveParallelismSizesTest(parameterized.TestCase):

def _make_mesh(self, total_devices):
Expand Down
195 changes: 160 additions & 35 deletions tunix/generate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,6 @@ def _repeat_to_model_shape(
for axis, (src_dim, tgt_dim) in enumerate(zip(src_shape, tgt_shape)):
if tgt_dim != src_dim:
result = jnp.repeat(result, tgt_dim // src_dim, axis=axis)

return result


Expand All @@ -1014,12 +1013,101 @@ def _delete_buffers(x):
jax.tree_util.tree_map(_delete_buffers, pytree)


@functools.partial(jax.jit, static_argnums=(2, 3))
def _jit_fuse_and_unstack_moe(
wi_0: jax.Array,
wi_1: jax.Array,
scan_axis: int,
num_layers: int,
) -> tuple[jax.Array, ...]:
"""Fuses wi_0/wi_1 along last axis, then unstacks along scan_axis.

By combining concatenation and unstacking under jax.jit, XLA can fuse both
ops and avoid materializing the full concatenated intermediate tensor on
device. scan_axis and num_layers are static so XLA knows the output tuple
size at compile time and can unroll the unstack at trace time.

Args:
wi_0: First MoE gate weight, shape [num_layers, experts, features].
wi_1: Second MoE gate weight, shape [num_layers, experts, features].
scan_axis: The axis along which layers are stacked (typically 0).
num_layers: Number of layers (must match wi_0.shape[scan_axis]).

Returns:
A tuple of num_layers fused per-layer arrays, each with shape
[experts, 2 * features].
"""
del num_layers # Only used to make this a static arg for JIT cache keying.
fused = jnp.concatenate([wi_0, wi_1], axis=-1)
return jnp.unstack(fused, axis=scan_axis)


def _fuse_moe_weights(src_flat: Dict[Tuple[str, ...], Any], tgt_flat: Dict[Tuple[str, ...], Any]) -> Dict[Tuple[str, ...], Any]:
"""Fuses wi_0 and wi_1 into wi if the target model expects fused MoE weights."""
new_src_flat = dict(src_flat)
for tgt_key in tgt_flat.keys():
if tgt_key and tgt_key[-1] == 'wi':
wi_0_key = tgt_key[:-1] + ('wi_0',)
wi_1_key = tgt_key[:-1] + ('wi_1',)
if wi_0_key in new_src_flat and wi_1_key in new_src_flat:
logging.info("Fusing MoE weights for %s", tgt_key)
wi_0 = new_src_flat.pop(wi_0_key)
wi_1 = new_src_flat.pop(wi_1_key)
new_src_flat[tgt_key] = jnp.concatenate([wi_0, wi_1], axis=-1)
del wi_0, wi_1 # Release references; .pop() already removed from dict.
return new_src_flat


def _reshard_in_chunks(
    src_flat: Dict[Tuple[str, ...], Any],
    spec_flat: Dict[Tuple[str, ...], Any],
    reshard_fn: Callable[..., Mapping[str, Any]],
    chunk_size: int,
) -> Dict[Tuple[str, ...], Any]:
  """Reshards a flat weight dict in sequential chunks to reduce peak HBM pressure.

  Instead of issuing one large jax.device_put for the entire model, this helper
  splits the flat key-value dict into groups of `chunk_size` keys and reshards
  each group independently. Between groups it calls jax.block_until_ready() so
  that the XLA allocator can reclaim the source buffers before committing the
  next chunk, keeping the peak contiguous allocation requirement proportional to
  chunk_size rather than the full model size.

  Args:
    src_flat: Flat dict mapping key tuples to source JAX arrays.
    spec_flat: Flat dict mapping the same key tuples to target-sharded arrays
      (used by reshard_fn to determine destination shardings).
    reshard_fn: Callable with the same signature as reshard_pytree, i.e.
      reshard_fn(source=<nested dict>, target=<nested dict>).
    chunk_size: Maximum number of flat keys to process per reshard call. Must
      be >= 1.

  Returns:
    A flat dict with the same keys as src_flat, containing resharded arrays.

  Raises:
    ValueError: If chunk_size is less than 1.
  """
  if chunk_size < 1:
    # Guard explicitly: range(..., step=0) raises an opaque ValueError, and a
    # negative step yields an empty range, silently dropping every weight.
    raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')
  keys = list(src_flat.keys())
  resharded: Dict[Tuple[str, ...], Any] = {}
  for start in range(0, len(keys), chunk_size):
    chunk_keys = keys[start : start + chunk_size]
    chunk_src = traverse_util.unflatten_dict(
        {k: src_flat[k] for k in chunk_keys}
    )
    chunk_spec = traverse_util.unflatten_dict(
        {k: spec_flat[k] for k in chunk_keys}
    )
    chunk_resharded = reshard_fn(source=chunk_src, target=chunk_spec)
    # Force completion so the allocator can reclaim this chunk's source
    # buffers before the next chunk's device transfer is committed.
    jax.block_until_ready(chunk_resharded)
    resharded.update(traverse_util.flatten_dict(chunk_resharded))
    del chunk_src, chunk_resharded
  return resharded


def transfer_state_directly(
src_state: Mapping[str, Any],
dst_state: Mapping[str, Any],
reshard_fn: Callable[..., Mapping[str, Any]],
scan_axis: int = 1,
delete_dst_buffers: bool = False,
reshard_chunk_size: Optional[int] = None,
) -> None:
"""Transfers state directly by matching structure, stripping wrappers.

Expand All @@ -1035,12 +1123,18 @@ def transfer_state_directly(
dst_state: The destination state to transfer to.
reshard_fn: A function to shard the values.
scan_axis: The axis along which to unroll scanned layers, if needed.
delete_dst_buffers: Whether to delete buffers in the destination state after transfer to save memory.
delete_dst_buffers: Whether to delete buffers in the destination state after
transfer to save memory.
reshard_chunk_size: When set, the final reshard is split into sequential
groups of this many flat keys instead of one monolithic call. This reduces
peak contiguous HBM pressure, which prevents XLA allocator fragmentation
errors on large models. A value of 50 (≈5-10 transformer layers) is a
reasonable starting point. When None (default) the original single-call
behavior is preserved.
"""

if delete_dst_buffers:
_delete_pytree_buffers(dst_state)
gc.collect()

def safe_has_key(obj: Mapping[str, Any], key: str) -> bool:
if isinstance(obj, dict):
Expand Down Expand Up @@ -1098,6 +1192,8 @@ def intersect_trees(
src_flat = traverse_util.flatten_dict(src)
tgt_flat = traverse_util.flatten_dict(tgt_spec)

src_flat = _fuse_moe_weights(src_flat, tgt_flat)

filtered_src_flat = {}
filtered_tgt_flat = {}

Expand All @@ -1106,9 +1202,6 @@ def intersect_trees(

layer_pattern = re.compile(r'^layers_(\d+)$')

# Cache to store unstacked scanned arrays to avoid repeated work
unstacked_cache = {}

for key_tuple, tgt_val in tgt_flat.items():
# Try Direct Match
if key_tuple in src_flat:
Expand Down Expand Up @@ -1153,39 +1246,61 @@ def intersect_trees(
break

if found_candidate:
# Apply the dtype cast and the repeating *before* unstacking
if found_candidate not in unstacked_cache:
src_val = src_flat[found_candidate]

# Cast the bulk tensor once
# Cast the bulk tensor once before unstacking.
src_val = _apply_dtype_cast(src_val, tgt_val.dtype, str(found_candidate))

# Predict the stacked target shape and repeat the bulk tensor once
src_shape = getattr(src_val, 'shape', None)
tgt_shape = getattr(tgt_val, 'shape', None)

if src_shape and tgt_shape and len(src_shape) == len(tgt_shape) + 1:
# Construct the 3D target shape (e.g., [layers, global_heads, dim])
stacked_tgt_shape = tgt_shape[:scan_axis] + (src_shape[scan_axis],) + tgt_shape[scan_axis:]

# Mock a target array purely to pass the shape to our repeat helper
class _MockTarget:
shape = stacked_tgt_shape

src_val = _repeat_to_model_shape(src_val, _MockTarget(), str(found_candidate))

# Unstack the already casted and repeated tensor using the provided scan_axis
unstacked_cache[found_candidate] = _unstack_scanned_param(
src_val, tgt_val, str(found_candidate), scan_axis=scan_axis
)

# Extract the layer_idx-th element from the unstacked cache
sliced_val = unstacked_cache[found_candidate][layer_idx]

# Extract the layer_idx-th element from the unstacked cache.
sliced_val = unstacked_cache[found_candidate][layer_idx]
# Apply KV-head repeat per-slice after unstacking (avoids _MockTarget hack).
sliced_val = _repeat_to_model_shape(sliced_val, tgt_val, str(key_tuple))
filtered_src_flat[key_tuple] = sliced_val
filtered_tgt_flat[key_tuple] = tgt_val
continue

# MoE fusion case: target has 'layers_X/.../wi' but source has scanned
# 'layers/.../wi_0' and 'layers/.../wi_1'. Fuse the full stacked
# tensors first, then unstack once via a JIT-compiled helper — avoids
# N per-layer jnp.concatenate dispatches and 2N intermediate device
# allocations that cause compilation pressure and memory fragmentation.
if key_tuple and key_tuple[-1] == 'wi':
scanned_prefix = (
key_tuple[:match_index] + ('layers',) + key_tuple[match_index + 1:-1]
)
wi_0_key = scanned_prefix + ('wi_0',)
wi_1_key = scanned_prefix + ('wi_1',)

if wi_0_key in src_flat and wi_1_key in src_flat:
# Use a synthetic cache key for the pre-fused scanned tensor so it
# is computed only once across all layer indices.
fused_scanned_key = scanned_prefix + ('wi_fused',)
if fused_scanned_key not in unstacked_cache:
logging.info(
'Fusing scanned MoE weights for %s', scanned_prefix
)
wi_0_full = _apply_dtype_cast(
src_flat[wi_0_key], tgt_val.dtype, str(wi_0_key)
)
wi_1_full = _apply_dtype_cast(
src_flat[wi_1_key], tgt_val.dtype, str(wi_1_key)
)
num_layers = src_flat[wi_0_key].shape[scan_axis]
# Single JIT-compiled fusion+unstack: XLA fuses concat and
# unstack into one program, avoiding a materialized intermediate.
unstacked_cache[fused_scanned_key] = _jit_fuse_and_unstack_moe(
wi_0_full, wi_1_full, scan_axis, num_layers
)
del wi_0_full, wi_1_full # Release references promptly.

sliced_val = unstacked_cache[fused_scanned_key][layer_idx]
filtered_src_flat[key_tuple] = sliced_val
filtered_tgt_flat[key_tuple] = tgt_val
continue

# Unflatten back to nested structure
return (
traverse_util.unflatten_dict(filtered_src_flat),
Expand All @@ -1200,15 +1315,25 @@ class _MockTarget:
final_source, final_spec = intersect_trees(full_source_dict, full_target_spec)

# Reshard and Update
resharded_weights = reshard_fn(
source=final_source,
target=final_spec,
)
if reshard_chunk_size is not None:
# Chunked path: split the flat weight dict into groups of reshard_chunk_size
# keys and reshard each group independently. This keeps peak contiguous HBM
# allocation proportional to chunk_size, avoiding XLA fragmentation errors
# on large models without needing to clear the compilation cache.
src_flat = traverse_util.flatten_dict(final_source)
spec_flat = traverse_util.flatten_dict(final_spec)
del final_source, final_spec
resharded_flat = _reshard_in_chunks(
src_flat, spec_flat, reshard_fn, reshard_chunk_size
)
resharded_weights = traverse_util.unflatten_dict(resharded_flat)
else:
resharded_weights = reshard_fn(
source=final_source,
target=final_spec,
)
nnx.update(dst_state, resharded_weights)

# Explicitly free memory
gc.collect()


def resolve_parallelism_sizes(
mesh: jax.sharding.Mesh,
Expand Down
6 changes: 3 additions & 3 deletions tunix/generate/vllm_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class VllmConfig:
data_parallel_size: int = -1
tensor_parallel_size: int = -1
expert_parallel_size: int = 1
reshard_chunk_size: Optional[int] = 50

# vLLM engine args that can be directly passed in without additional processing, e.g. max_model_len, async_scheduling, etc.
engine_kwargs: dataclasses.InitVar[Optional[Dict[str, Any]]] = None
Expand Down Expand Up @@ -190,9 +191,7 @@ def update_params(
self._driver.llm_engine.reset_prefix_cache()
self._driver.llm_engine.collective_rpc("delete_kv_cache")

# Perform explicit garbage collection and synchronization to free up HBM memory before loading new weights
gc.collect()
jax.clear_caches()
# Synchronization point before weight sync
jax.effects_barrier()

if self.to_hf_key_mappings:
Expand Down Expand Up @@ -236,6 +235,7 @@ def update_params(
dst_state=self.transformer_state,
reshard_fn=reshard.reshard_pytree,
delete_dst_buffers=True, # Ensure old weights are deleted to free up HBM memory
reshard_chunk_size=self.config.reshard_chunk_size,
)

if self.llm is not None:
Expand Down
Loading