diff --git a/MaxText/layers/deepseek.py b/MaxText/layers/deepseek.py
index f59e68146..e363e2bbe 100644
--- a/MaxText/layers/deepseek.py
+++ b/MaxText/layers/deepseek.py
@@ -32,7 +32,7 @@ from MaxText.layers import linears
 from MaxText.common_types import Config
 from MaxText.layers.normalizations import rms_norm
-from MaxText.layers import moe
+from MaxText.layers import moe_linen
 from MaxText.layers import quantizations
 from MaxText.layers.quantizations import AqtQuantization as Quant
 from MaxText.inference import page_manager
@@ -257,7 +257,7 @@ def __call__(
     # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints.
     # The `name` represents the weight name in JAX/checkpoints and so the class name
     # is just for readability.
-    mlp_lnx = moe.RoutedAndSharedMoE(
+    mlp_lnx = moe_linen.get_routed_and_shared_moe(
         name="DeepSeekMoeBlock_0",
         config=cfg,
         mesh=self.mesh,
diff --git a/MaxText/layers/llama4.py b/MaxText/layers/llama4.py
index 256f3eb8f..50643341e 100644
--- a/MaxText/layers/llama4.py
+++ b/MaxText/layers/llama4.py
@@ -32,7 +32,7 @@ from MaxText.layers import initializers
 from MaxText.layers.linears import mlp_block
 from MaxText.layers import linears
-from MaxText.layers import moe
+from MaxText.layers import moe_linen
 from MaxText.layers import quantizations
 from MaxText.layers import attentions
 from MaxText.layers.attentions import AttentionType
@@ -472,7 +472,7 @@ def __call__(
     # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints.
     # The `name` represents the weight name in JAX/checkpoints and so the class name
     # is just for readability.
-    mlp_lnx = moe.RoutedAndSharedMoE(
+    mlp_lnx = moe_linen.get_routed_and_shared_moe(
         name="Llama4MoEBlock_0",
         config=cfg,
         mesh=self.mesh,
diff --git a/MaxText/layers/mixtral.py b/MaxText/layers/mixtral.py
index d5dc65b23..6d397d1a2 100644
--- a/MaxText/layers/mixtral.py
+++ b/MaxText/layers/mixtral.py
@@ -29,7 +29,7 @@ from MaxText.layers import initializers
 from MaxText.layers import models
-from MaxText.layers import moe
+from MaxText.layers import moe_linen
 from MaxText.layers import quantizations
 from MaxText.layers.attentions import attention_as_linen
 from MaxText.layers.quantizations import AqtQuantization as Quant
@@ -135,7 +135,7 @@ def __call__(
     # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints.
     # The `name` represents the weight name in JAX/checkpoints and so the class name
     # is just for readability.
-    mlp_lnx, load_balance_loss = moe.RoutedMoE(
+    mlp_lnx, load_balance_loss = moe_linen.get_routed_moe(
        name="MoeBlock_0",
        config=cfg,
        num_experts=cfg.num_experts,
diff --git a/MaxText/layers/moe.py b/MaxText/layers/moe.py
index 31c123583..612db1678 100644
--- a/MaxText/layers/moe.py
+++ b/MaxText/layers/moe.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
-"""MoE related Layers.""" +"""MoE related NNX Layers.""" import enum import functools @@ -22,20 +22,20 @@ from aqt.jax.v2 import aqt_tensor as aqt import flax.linen as nn +from flax import nnx import jax from jax import ad_checkpoint as adc from jax.experimental import shard_map from jax.experimental import xla_metadata import jax.numpy as jnp +import numpy as np + from MaxText import common_types as ctypes from MaxText import max_logging from MaxText import max_utils from MaxText.kernels import megablox as mblx -from MaxText.layers import attentions -from MaxText.layers import initializers -from MaxText.layers import linears -from MaxText.layers import quantizations -import numpy as np +from MaxText.layers import attentions, linears, quantizations, nnx_wrappers +from MaxText.layers.initializers import NdInitializer, nd_dense_init, default_bias_init set_xla_metadata = xla_metadata.set_xla_metadata @@ -72,74 +72,122 @@ def random_routing(rng_key, gate_logits, num_experts_per_tok): return top_k_weights, top_k_indices -class GateLogit(nn.Module): - """A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing. - - Attributes: - features: tuple with numbers of output features. - model_name: which model to run. - axis: tuple with axes to apply the transformation on. - weight_dtype: the dtype of the weights (default: float32). - dtype: the dtype of the computation (default: float32). - kernel_init: initializer function for the weight matrix. - kernel_axes: tuple with axes to apply kernel function. - use_bias: whether to add learnable bias in gate logit scores. When enabled, - this bias aids expert load balancing (like in DeepSeek V3), and is not - part of the loss calculation. - score_func: scoring function for output normalization before applying bias. - quant: quantization config, defaults to None implying no quantization. - matmul_precision: precision for JAX functions. - """ +class GateLogit(nnx.Module): + """A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing.""" - features: Union[Iterable[int], int] - model_name: str - axis: Union[Iterable[int], int] = -1 - weight_dtype: ctypes.DType = jnp.float32 - dtype: ctypes.DType = jnp.float32 - kernel_init: attentions.NdInitializer = attentions.nd_dense_init(1.0, "fan_in", "truncated_normal") - kernel_axes: Tuple[Optional[str], ...] = () - use_bias: bool = False - score_func: str = "" - quant: Optional[quantizations.AqtQuantization] = None - matmul_precision: str = "default" + def __init__( + self, + in_features_shape: Union[Iterable[int], int], + out_features_shape: Union[Iterable[int], int], + model_name: str, + rngs: nnx.Rngs, + axis: Union[Iterable[int], int] = -1, + weight_dtype: ctypes.DType = jnp.float32, + dtype: ctypes.DType = jnp.float32, + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes: Tuple[Optional[str], ...] = (), + use_bias: bool = False, + score_func: str = "", + quant: Optional[quantizations.AqtQuantization] = None, + matmul_precision: str = "default", + ): + """Initializes the GateLogit module. + + Attributes: + in_features_shape: The shape of the input features. + out_features_shape: The shape of the output features, typically the number of experts. + model_name: The name of the model. + rngs: An `nnx.Rngs` object used for initializing parameters. + axis: The axis or axes over transformation is applied. + weight_dtype: The data type of the kernel weights. + dtype: The data type for the computation. 
+      kernel_init: The initializer function for the kernel weight matrix.
+      kernel_axes: A tuple of logical axis names for partitioning the kernel.
+      use_bias: Whether to add learnable bias in gate logit scores. When enabled,
+        this bias aids expert load balancing (like in DeepSeek V3), and is not
+        part of the loss calculation.
+      score_func: Scoring function for output normalization before applying bias.
+      quant: The quantization configuration. If None, no quantization is applied.
+      matmul_precision: The precision level for the matrix multiplication.
+    """
+    self.in_features_shape = linears.canonicalize_tuple(in_features_shape)
+    self.out_features_shape = linears.canonicalize_tuple(out_features_shape)
+    self.model_name = model_name
+    self.axis = linears.canonicalize_tuple(axis)
+    self.weight_dtype = weight_dtype
+    self.dtype = dtype
+    self.kernel_init = kernel_init
+    self.kernel_axes = kernel_axes
+    self.use_bias = use_bias
+    self.score_func = score_func
+    self.quant = quant
+    self.matmul_precision = matmul_precision
+
+    # Parameter initialization
+    kernel_shape = self.in_features_shape + self.out_features_shape
+    kernel_in_axis = np.arange(len(self.axis))
+    kernel_out_axis = np.arange(len(self.axis), len(self.axis) + len(self.out_features_shape))
+
+    if not quantizations.in_serve_mode(self.quant):
+      self.kernel = nnx.Param(
+          self.kernel_init(
+              rngs.params(),
+              kernel_shape,
+              self.weight_dtype,
+              kernel_in_axis,
+              kernel_out_axis,
+          ),
+          sharding=self.kernel_axes,
+      )
+
+    if self.use_bias:
+      bias_axes = self.kernel_axes[-len(self.out_features_shape) :]
+      bias_shape = kernel_shape[-len(self.out_features_shape) :]
+      self.bias = nnx.Param(
+          default_bias_init(rngs.params(), bias_shape, self.weight_dtype),
+          sharding=bias_axes,
+      )
+    else:
+      self.bias = None
+
+    if quant:
+      dot_general_cls = quant.dot_general_cls(mesh_axes=kernel_axes)
+      dot_general_linen = dot_general_cls()
+      quant_dot_general = nnx_wrappers.ToNNX(dot_general_linen, rngs=rngs)
+      self._quant_dot_general_name = f"{type(dot_general_linen).__name__}_0"
+      setattr(self, self._quant_dot_general_name, quant_dot_general)
+      dummy_inputs = jnp.zeros((1, *self.in_features_shape), dtype=self.dtype)
+      self(dummy_inputs, _initializing=True)
+    else:
+      self._quant_dot_general_name = None
 
-  @nn.compact
-  def __call__(self, inputs: ctypes.Array) -> Tuple[ctypes.Array, Optional[ctypes.Array]]:
+  @property
+  def quant_dot_general(self) -> nnx_wrappers.ToNNX | None:
+    if self._quant_dot_general_name is None:
+      return None
+    return getattr(self, self._quant_dot_general_name)
 
-    features = linears.canonicalize_tuple(self.features)
-    axis = linears.canonicalize_tuple(self.axis)
+  def __call__(self, inputs: ctypes.Array, _initializing: bool = False) -> Tuple[ctypes.Array, Optional[ctypes.Array]]:
     inputs = jnp.asarray(inputs, self.dtype)
-    axis = linears.normalize_axes(axis, inputs.ndim)
+    norm_axis = linears.normalize_axes(self.axis, inputs.ndim)
 
-    kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
-    kernel_in_axis = np.arange(len(axis))
-    kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
     if quantizations.in_serve_mode(self.quant):
-      # During aqt convert state we delete kernel weight from `params` to save
-      # memory and instead retrieve them from the tensors stored in the 'aqt'
-      # collection.
-      kernel = jnp.zeros(kernel_shape)
+      kernel_shape = self.in_features_shape + self.out_features_shape
+      kernel = jnp.zeros(kernel_shape, dtype=self.dtype)
     else:
-      kernel = self.param(
-          "kernel",
-          nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
-          kernel_shape,
-          self.weight_dtype,
-          kernel_in_axis,
-          kernel_out_axis,
-      )
+      kernel = self.kernel[...]
     kernel = jnp.asarray(kernel, self.dtype)
 
-    contract_ind = tuple(range(0, len(axis)))
-    output = linears._compute_dot_general(
+    contract_ind = tuple(range(0, len(norm_axis)))
+    output = linears._compute_dot_general_nnx(
         inputs,
         kernel,
-        self.kernel_axes,
-        axis,
+        norm_axis,
        contract_ind,
        self.matmul_precision,
-        self.quant,
+        self.quant_dot_general,
+        _initializing,
     )
 
     pre_bias_logits = None
@@ -149,65 +197,71 @@ def __call__(self, inputs: ctypes.Array) -> Tuple[ctypes.Array, Optional[ctypes.
       pre_bias_logits = output
 
     if self.use_bias:
-      bias_axes, bias_shape = (
-          self.kernel_axes[-len(features) :],
-          kernel_shape[-len(features) :],
-      )
-      bias = self.param(
-          "bias",
-          nn.with_logical_partitioning(initializers.default_bias_init, bias_axes),
-          bias_shape,
-          self.weight_dtype,
-      )
-      bias = jnp.asarray(bias, self.dtype)
+      bias = jnp.asarray(self.bias[...], self.dtype)
       output += bias
 
     return output, pre_bias_logits
 
 
-class RoutedMoE(nn.Module):
-  """Implements a routed MoE block.
+class RoutedMoE(nnx.Module):
+  """Implements a routed MoE block."""
 
-  Attributes:
-    num_experts: Number of experts.
-    num_experts_per_tok: Number of experts for each token.
-    mesh: Mesh, device mesh.
-    kernel_init: Kernel function, passed to the dense layers.
-    kernel_axes: Tuple with axes to apply kernel function.
-    intermediate_dim: Intermediate dimension of MoE.
-    weight_dtype: Type for the weights.
-    dtype: Type for the dense layer.
-    quant: Optional quantization config, no quantization if None.
-  """
-
-  config: ctypes.Config
-  num_experts: int
-  num_experts_per_tok: int
-  mesh: jax.sharding.Mesh
-  kernel_init: attentions.NdInitializer
-  kernel_axes: Tuple[Optional[str], ...]
-  intermediate_dim: int = 2048
-  weight_dtype: ctypes.DType = jnp.float32
-  dtype: ctypes.DType = jnp.float32
-  quant: Optional[quantizations.AqtQuantization] = None
-
-  # The first axes is expert
-  wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
-  wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
-
-  def get_expert_parallelism_size(self):
-    return self.mesh.shape["expert"]
-
-  def get_tensor_parallelism_size(self):
-    return self.mesh.shape["tensor"]
-
-  def get_tensor_transpose_parallelism_size(self):
-    return self.mesh.shape["tensor_transpose"]
-
-  def get_context_autoregressive_parallelism_size(self):
-    return self.mesh.shape["context_autoregressive"]
-
-  def generate_kernels(self, num_experts, emb_dim, mlp_dim):
-    """generates kernels."""
+  def __init__(
+      self,
+      config: ctypes.Config,
+      num_experts: int,
+      num_experts_per_tok: int,
+      mesh: jax.sharding.Mesh,
+      kernel_init: attentions.NdInitializer,
+      kernel_axes: Tuple[Optional[str], ...],
+      rngs: nnx.Rngs,
+      intermediate_dim: int = 2048,
+      weight_dtype: ctypes.DType = jnp.float32,
+      dtype: ctypes.DType = jnp.float32,
+      quant: Optional[quantizations.AqtQuantization] = None,
+  ):
+    """Initializes the RoutedMoE module.
+
+    Attributes:
+      config: The main config setting.
+      num_experts: Number of experts.
+      num_experts_per_tok: Number of experts for each token.
+      mesh: Mesh, device mesh.
+      kernel_init: The initializer function for the kernel weight matrix.
+      kernel_axes: A tuple of logical axis names for partitioning the kernel.
+      rngs: An `nnx.Rngs` object used for initializing parameters.
+      intermediate_dim: Intermediate dimension of MoE.
+      weight_dtype: The data type of the kernel weights.
+      dtype: The data type for the computation.
+      quant: The quantization configuration. If None, no quantization is applied.
+    """
+    self.config = config
+    self.num_experts = num_experts
+    self.num_experts_per_tok = num_experts_per_tok
+    self.mesh = mesh
+    self.kernel_init = kernel_init
+    self.kernel_axes = kernel_axes
+    self.intermediate_dim = intermediate_dim
+    self.weight_dtype = weight_dtype
+    self.dtype = dtype
+    self.quant = quant
+
+    self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
+    self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
+
+    self.gate = GateLogit(
+        in_features_shape=self.config.emb_dim,
+        out_features_shape=self.num_experts,
+        model_name=self.config.model_name,
+        dtype=self.dtype,
+        weight_dtype=self.weight_dtype,
+        quant=self.quant,
+        kernel_init=self.kernel_init,
+        kernel_axes=self.kernel_axes,
+        use_bias=self.config.routed_bias,
+        score_func=self.config.routed_score_func,
+        matmul_precision=self.config.matmul_precision,
+        rngs=rngs,
+    )
     kernel_in_axis = np.arange(1)
     kernel_out_axis = np.arange(1, 2)
@@ -217,51 +271,58 @@ def generate_kernels(self, num_experts, emb_dim, mlp_dim):
       # During aqt convert state we delete kernel weight from params to save
       # memory. Instead they are retrieved from the tensors stored in the 'aqt'
       # collection.
-      w0_kernel = jnp.zeros((num_experts, emb_dim, mlp_dim))
+      self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
     else:
-      w0_kernel = self.param(
-          "wi_0",
-          nn.with_logical_partitioning(kernel_init, self.wi_kernel_axes),
-          (num_experts, emb_dim, mlp_dim),
-          self.weight_dtype,
-          kernel_in_axis,
-          kernel_out_axis,
+      self.wi_0 = nnx.Param(
+          self.kernel_init(
+              rngs.params(),
+              (num_experts, self.config.emb_dim, intermediate_dim),
+              weight_dtype,
+              kernel_in_axis,
+              kernel_out_axis,
+          ),
+          sharding=self.wi_kernel_axes,
       )
-    w0_kernel = jnp.asarray(w0_kernel, self.dtype)
-
     if quantizations.in_serve_mode(self.quant):
-      # During aqt convert state we delete kernel weight from params to save
-      # memory. Instead they are retrieved from the tensors stored in the 'aqt'
-      # collection.
-      w1_kernel = jnp.zeros((num_experts, emb_dim, mlp_dim))
+      self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
     else:
-      w1_kernel = self.param(
-          "wi_1",
-          nn.with_logical_partitioning(kernel_init, self.wi_kernel_axes),
-          (num_experts, emb_dim, mlp_dim),
-          self.weight_dtype,
-          kernel_in_axis,
-          kernel_out_axis,
+      self.wi_1 = nnx.Param(
+          self.kernel_init(
+              rngs.params(),
+              (num_experts, self.config.emb_dim, intermediate_dim),
+              weight_dtype,
+              kernel_in_axis,
+              kernel_out_axis,
+          ),
+          sharding=self.wi_kernel_axes,
      )
-    w1_kernel = jnp.asarray(w1_kernel, self.dtype)
     if quantizations.in_serve_mode(self.quant):
-      # During aqt convert state we delete kernel weight from params to save
-      # memory. Instead they are retrieved from the tensors stored in the 'aqt'
-      # collection.
-      wo_kernel = jnp.zeros((num_experts, mlp_dim, emb_dim))
+      self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim))
     else:
-      wo_kernel = self.param(
-          "wo",
-          nn.with_logical_partitioning(kernel_init, self.wo_kernel_axes),
-          (num_experts, mlp_dim, emb_dim),
-          self.weight_dtype,
-          kernel_in_axis,
-          kernel_out_axis,
+      self.wo = nnx.Param(
+          self.kernel_init(
+              rngs.params(),
+              (self.num_experts, self.intermediate_dim, self.config.emb_dim),
+              self.weight_dtype,
+              kernel_in_axis,
+              kernel_out_axis,
+          ),
+          sharding=self.wo_kernel_axes,
      )
-    wo_kernel = jnp.asarray(wo_kernel, self.dtype)
-    return w0_kernel, w1_kernel, wo_kernel
+
+  def get_expert_parallelism_size(self):
+    return self.mesh.shape["expert"]
+
+  def get_tensor_parallelism_size(self):
+    return self.mesh.shape["tensor"]
+
+  def get_tensor_transpose_parallelism_size(self):
+    return self.mesh.shape["tensor_transpose"]
+
+  def get_context_autoregressive_parallelism_size(self):
+    return self.mesh.shape["context_autoregressive"]
 
   def get_topk(self, gate_logits, pre_bias_logits):
     """get topk."""
@@ -291,7 +352,7 @@ def deepseek_scale_weights(self, weights):
     weights *= self.config.routed_scaling_factor
     return weights
 
-  def _expert_group_mask(self, gate_logits: jax.Array) -> jax.Array:
+  def expert_group_mask(self, gate_logits: jax.Array) -> jax.Array:
     """Returns a mask that selects only the top-k groups of experts.
 
     Groups of experts are selected based on the sum of the top-2 expert scores
@@ -316,9 +377,7 @@ def _expert_group_mask(self, gate_logits: jax.Array) -> jax.Array:
     _, group_idx = jax.lax.top_k(group_scores, k=self.config.topk_routing_group)
 
     # Mask selected groups so that only those experts are considered.
-    group_mask = jax.nn.one_hot(
-        group_idx, num_classes=self.config.n_routing_groups, dtype=jnp.float32
-    )
+    group_mask = jax.nn.one_hot(group_idx, num_classes=self.config.n_routing_groups, dtype=jnp.float32)
     group_mask = jnp.sum(group_mask, axis=-2)
 
     # Apply masks and get top-k indices.
@@ -331,9 +390,7 @@ def _expert_group_mask(self, gate_logits: jax.Array) -> jax.Array:
         score_mask_expanded.shape[:-2] + (self.num_experts,),
     )
 
-  def deepseek_routing(
-      self, gate_logits: jax.Array, pre_bias_logits: jax.Array
-  ) -> tuple[jax.Array, jax.Array]:
+  def deepseek_routing(self, gate_logits: jax.Array, pre_bias_logits: jax.Array) -> tuple[jax.Array, jax.Array]:
     """DeepSeek routing logit.
 
     If the configuration does not specify routing groups (`n_routing_groups` is
@@ -353,11 +410,7 @@ def deepseek_routing(
       - top_k_indices: `(..., num_experts_per_tok)` array of indices identifying
        the selected experts for each token.
""" - expert_mask = ( - 1 - if self.config.n_routing_groups == -1 - else self._expert_group_mask(gate_logits) - ) + expert_mask = 1 if self.config.n_routing_groups == -1 else self.expert_group_mask(gate_logits) _, top_k_indices = jax.lax.top_k( jnp.where(expert_mask > 0, gate_logits, -jnp.inf), k=self.num_experts_per_tok, @@ -662,9 +715,7 @@ def gmm(inputs, kernel, group_sizes, expert_assignments): if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1: raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.") rhs_inputs = kernel.qvalue - with set_xla_metadata( - ragged_dot_tiling=",".join([str(t) for t in tiling]) - ): + with set_xla_metadata(ragged_dot_tiling=",".join([str(t) for t in tiling])): output = jax.lax.ragged_dot( lhs=inputs, rhs=rhs_inputs, @@ -1415,27 +1466,15 @@ def retrieve_quantized_weight( wo_kernel = max_utils.unbox_logicallypartioned(wo_kernel) return w0_kernel, w1_kernel, wo_kernel - @nn.compact - def __call__( - self, inputs: ctypes.Array - ) -> tuple[ctypes.Array, Optional[ctypes.Array]]: + def __call__(self, inputs: ctypes.Array) -> tuple[ctypes.Array, Optional[ctypes.Array]]: cfg = self.config inputs = inputs.astype(cfg.dtype) - gate_logits, pre_bias_logits = GateLogit( - self.num_experts, - model_name=cfg.model_name, - dtype=self.dtype, - weight_dtype=self.weight_dtype, - quant=self.quant, - kernel_init=self.kernel_init, - kernel_axes=self.kernel_axes, - name="gate", - use_bias=cfg.routed_bias, - score_func=cfg.routed_score_func, - matmul_precision=cfg.matmul_precision, - )(inputs) + gate_logits, pre_bias_logits = self.gate(inputs) + + w0_kernel = jnp.asarray(self.wi_0.value, self.dtype) + w1_kernel = jnp.asarray(self.wi_1.value, self.dtype) + wo_kernel = jnp.asarray(self.wo.value, self.dtype) - w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts, cfg.emb_dim, self.intermediate_dim) if cfg.sparse_matmul: if quantizations.in_serve_mode(self.quant): w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight( @@ -1451,59 +1490,74 @@ def __call__( return self.dense_matmul(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel) -class RoutedAndSharedMoE(nn.Module): - """Implements a block which combines shared and routed experts. +class RoutedAndSharedMoE(nnx.Module): + """Implements a block which combines shared and routed experts.""" - Attributes: - config: Model configs. - mesh: device mesh. - kernel_init: Kernel function, passed to the dense layers. - kernel_axes: Tuple with axes to apply kernel function. - weight_dtype: Type for the weights. - dtype: Type for the dense layer. - quant: Optional quantization config, no quantization if None. - """ - - config: ctypes.Config - mesh: jax.sharding.Mesh - kernel_init: attentions.NdInitializer - kernel_axes: Tuple[Optional[str], ...] - weight_dtype: ctypes.DType = jnp.float32 - dtype: ctypes.DType = jnp.float32 - quant: Optional[quantizations.AqtQuantization] = None - - @nn.compact - def __call__(self, inputs: ctypes.Array) -> ctypes.Array: - cfg = self.config - # NOTE: the naming mismatch here is to ensure reverse compatibility with - # existing checkpoints. The `name` represents the weight name in - # JAX/checkpoints and so the class name is just for readability. 
-    routed_experts, _ = RoutedMoE(
-        name="MoeBlock_0",
-        config=cfg,
-        num_experts=cfg.num_experts,
-        num_experts_per_tok=cfg.num_experts_per_tok,
+  def __init__(
+      self,
+      config: ctypes.Config,
+      mesh: jax.sharding.Mesh,
+      kernel_init: NdInitializer,
+      kernel_axes: Tuple[Optional[str], ...],
+      rngs: nnx.Rngs,
+      weight_dtype: ctypes.DType = jnp.float32,
+      dtype: ctypes.DType = jnp.float32,
+      quant: Optional[quantizations.AqtQuantization] = None,
+  ):
+    """Initializes the RoutedAndSharedMoE module.
+
+    Attributes:
+      config: The main config setting.
+      mesh: Mesh, device mesh.
+      kernel_init: The initializer function for the kernel weight matrix.
+      kernel_axes: A tuple of logical axis names for partitioning the kernel.
+      rngs: An `nnx.Rngs` object used for initializing parameters.
+      weight_dtype: The data type of the kernel weights.
+      dtype: The data type for the computation.
+      quant: The quantization configuration. If None, no quantization is applied.
+    """
+    self.config = config
+    self.mesh = mesh
+    self.kernel_init = kernel_init
+    self.kernel_axes = kernel_axes
+    self.weight_dtype = weight_dtype
+    self.dtype = dtype
+    self.quant = quant
+    self.rngs = rngs
+    self.routed_moe = RoutedMoE(
+        config=self.config,
+        num_experts=self.config.num_experts,
+        num_experts_per_tok=self.config.num_experts_per_tok,
        mesh=self.mesh,
-        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
+        kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"),
        kernel_axes=("embed", None),
-        intermediate_dim=cfg.moe_mlp_dim,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
+        intermediate_dim=self.config.moe_mlp_dim,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
        quant=self.quant,
-    )(inputs)
-
-    shared_experts = linears.mlp_block(
-        in_features=inputs.shape[-1],
-        intermediate_dim=cfg.shared_experts * cfg.moe_mlp_dim,
-        activations=cfg.mlp_activations,
-        intermediate_dropout_rate=cfg.dropout_rate,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="shared_experts",
-        config=cfg,
+        rngs=self.rngs,
+    )
+
+    self.shared_moe = linears.MlpBlock(
+        in_features=self.config.emb_dim,
+        intermediate_dim=self.config.shared_experts * self.config.moe_mlp_dim,
+        activations=self.config.mlp_activations,
+        intermediate_dropout_rate=self.config.dropout_rate,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
+        config=self.config,
        quant=self.quant,
-    )(inputs)
-    logical_axis_names = ("activation_batch", "activation_length", "activation_embed")
-    shared_experts = nn.with_logical_constraint(shared_experts, logical_axis_names)
+        rngs=self.rngs,
+    )
+    # NOTE: the naming mismatch here is to ensure reverse compatibility with
+    # existing checkpoints.
+    self.routed_moe_ckpt_name = "MoeBlock_0"
+    setattr(self, self.routed_moe_ckpt_name, self.routed_moe)
+    self.shared_moe_ckpt_name = "shared_experts"
+    setattr(self, self.shared_moe_ckpt_name, self.shared_moe)
+
+  def __call__(self, inputs: ctypes.Array) -> ctypes.Array:
+    routed_experts, _ = getattr(self, self.routed_moe_ckpt_name)(inputs)
+    shared_experts = getattr(self, self.shared_moe_ckpt_name)(inputs)
     return routed_experts + shared_experts
diff --git a/MaxText/layers/moe_linen.py b/MaxText/layers/moe_linen.py
new file mode 100644
index 000000000..dab566b98
--- /dev/null
+++ b/MaxText/layers/moe_linen.py
@@ -0,0 +1,128 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""MoE related Linen Layers."""
+
+from typing import Iterable, Optional, Tuple, Union
+import jax
+import jax.numpy as jnp
+from MaxText import common_types as ctypes
+from MaxText.layers import moe
+from MaxText.layers import linears
+from MaxText.layers import quantizations
+from MaxText.layers import nnx_wrappers
+from MaxText.layers.initializers import NdInitializer, nd_dense_init, variable_to_logically_partitioned
+
+
+def get_gate_logit(
+    inputs_shape: tuple[int, ...],
+    out_features_shape: Union[Iterable[int], int],
+    model_name: str,
+    axis: Union[Iterable[int], int] = -1,
+    weight_dtype: ctypes.DType = jnp.float32,
+    dtype: ctypes.DType = jnp.float32,
+    kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
+    kernel_axes: Tuple[Optional[str], ...] = (),
+    use_bias: bool = False,
+    score_func: str = "",
+    quant: Optional[quantizations.AqtQuantization] = None,
+    matmul_precision: str = "default",
+    name: Optional[str] = None,
+):
+  """Creates a GateLogit Linen module."""
+
+  axis = linears.canonicalize_tuple(axis)
+  in_features_shape = tuple(inputs_shape[ax] for ax in linears.normalize_axes(axis, len(inputs_shape)))
+
+  module = nnx_wrappers.to_linen(
+      moe.GateLogit,
+      in_features_shape=in_features_shape,
+      out_features_shape=out_features_shape,
+      model_name=model_name,
+      axis=axis,
+      weight_dtype=weight_dtype,
+      dtype=dtype,
+      kernel_init=kernel_init,
+      kernel_axes=kernel_axes,
+      use_bias=use_bias,
+      score_func=score_func,
+      quant=quant,
+      matmul_precision=matmul_precision,
+      name=name,
+      metadata_fn=variable_to_logically_partitioned,
+      abstract_init=False,
+  )
+  return module
+
+
+def get_routed_moe(
+    config: ctypes.Config,
+    num_experts: int,
+    num_experts_per_tok: int,
+    mesh: jax.sharding.Mesh,
+    kernel_init: NdInitializer,
+    kernel_axes: Tuple[Optional[str], ...],
+    intermediate_dim: int = 2048,
+    weight_dtype: ctypes.DType = jnp.float32,
+    dtype: ctypes.DType = jnp.float32,
+    quant: Optional[quantizations.AqtQuantization] = None,
+    name: Optional[str] = None,
+):
+  """Creates a RoutedMoE Linen module."""
+
+  module = nnx_wrappers.to_linen(
+      moe.RoutedMoE,
+      config=config,
+      num_experts=num_experts,
+      num_experts_per_tok=num_experts_per_tok,
+      mesh=mesh,
+      kernel_init=kernel_init,
+      kernel_axes=kernel_axes,
+      intermediate_dim=intermediate_dim,
+      weight_dtype=weight_dtype,
+      dtype=dtype,
+      quant=quant,
+      name=name,
+      metadata_fn=variable_to_logically_partitioned,
+      abstract_init=False,
+  )
+  return module
+
+
+def get_routed_and_shared_moe(
+    config: ctypes.Config,
+    mesh: jax.sharding.Mesh,
+    kernel_init: NdInitializer,
+    kernel_axes: Tuple[Optional[str], ...],
+    weight_dtype: ctypes.DType = jnp.float32,
+    dtype: ctypes.DType = jnp.float32,
+    quant: Optional[quantizations.AqtQuantization] = None,
+    name: Optional[str] = None,
+):
+  """Creates a RoutedAndSharedMoE Linen module."""
+
+  module = nnx_wrappers.to_linen(
+      moe.RoutedAndSharedMoE,
+      config=config,
+      mesh=mesh,
+      kernel_init=kernel_init,
+      kernel_axes=kernel_axes,
+      weight_dtype=weight_dtype,
+      dtype=dtype,
+      quant=quant,
+      name=name,
+      metadata_fn=variable_to_logically_partitioned,
+      abstract_init=False,
+  )
+  return module
diff --git a/MaxText/layers/qwen3.py b/MaxText/layers/qwen3.py
index 9bb6b502c..fe47e75e1 100644
--- a/MaxText/layers/qwen3.py
+++ b/MaxText/layers/qwen3.py
@@ -30,7 +30,7 @@ from MaxText.layers import attentions
 from MaxText.layers import initializers
 from MaxText.layers import linears
-from MaxText.layers import moe
+from MaxText.layers import moe_linen
 from MaxText.layers import quantizations
 from MaxText.layers.normalizations import rms_norm
 from MaxText.layers.quantizations import AqtQuantization as Quant
@@ -136,7 +136,7 @@ def __call__(
          quant=self.quant,
      )(mlp_input, deterministic=deterministic)
    else:  # Mixture of Experts MLP -- not supported / tested in MaxText
-      mlp_output, _ = moe.RoutedMoE(
+      mlp_output, _ = moe_linen.get_routed_moe(
          config=cfg,
          num_experts=cfg.num_experts,
          num_experts_per_tok=cfg.num_experts_per_tok,
diff --git a/MaxText/tests/moe_test.py b/MaxText/tests/moe_test.py
index 6c80baa21..e7dc56d01 100644
--- a/MaxText/tests/moe_test.py
+++ b/MaxText/tests/moe_test.py
@@ -24,16 +24,18 @@
 from jax.sharding import Mesh
 
 import flax.linen as nn
+from flax import nnx
 from flax.linen import partitioning as nn_partitioning
 
 from MaxText import maxtext_utils
 from MaxText import pyconfig
 from MaxText.common_types import Config, DType
 from MaxText.globals import PKG_DIR
-from MaxText.layers.linears import mlp_block
-from MaxText.layers import moe
-from MaxText.layers.initializers import NdInitializer, nd_dense_init
+from MaxText.layers import linears
+from MaxText.layers import moe, moe_linen
+from MaxText.layers.initializers import NdInitializer, nd_dense_init, variable_to_logically_partitioned
 from MaxText.layers.quantizations import Fp8Quantization
+from MaxText.layers import nnx_wrappers
 
 
 class TokenDroppingTest(unittest.TestCase):
@@ -52,10 +54,9 @@ def setUp(self):
        per_device_batch_size=1,
        capacity_factor=2,
    )
-    self.rng = jax.random.PRNGKey(42)
+    self.rngs = nnx.Rngs(params=0)
    devices_array = maxtext_utils.create_device_mesh(self.cfg)
    self.model = moe.RoutedMoE(
-        name="MoeBlock",
        config=self.cfg,
        num_experts=self.cfg.num_experts,
        num_experts_per_tok=self.cfg.num_experts_per_tok,
@@ -63,6 +64,7 @@ def setUp(self):
        kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"),
        kernel_axes=("embed", "mlp"),
        dtype=self.cfg.dtype,
+        rngs=self.rngs,
    )
 
  def test_generate_masks(self):
@@ -159,13 +161,14 @@ def test_generate_masks(self):
    self.assertTrue((expected_dispatch_mask == actual_dispatch_mask).all())
    self.assertTrue(jax.numpy.allclose(expected_combine_mask, actual_combine_mask, rtol=1e-02, atol=1e-02))
 
+
 class MlpBlockTest(unittest.TestCase):
 
  def setUp(self):
    super().setUp()
    self.config = pyconfig.initialize(
        [None, os.path.join(PKG_DIR, "configs", "base.yml")],
-        run_name="token_dropping_test",
+        run_name="mlp_block_init_test",
        enable_checkpointing=False,
        model_name="mixtral-8x7b",
        dtype="bfloat16",
@@ -177,7 +180,7 @@ def setUp(self):
    )
    self.rng = jax.random.PRNGKey(42)
    quant = Fp8Quantization()
-    self.model = mlp_block(
+    self.model = linears.mlp_block(
        config=self.config,
        in_features=2,
        intermediate_dim=2,
@@ -187,13 +190,14 @@ def setUp(self):
        weight_dtype=jnp.bfloat16,
        name="mlp",
        quant=quant,
-        use_bias=True
+        use_bias=True,
    )
 
  def test_init(self):
    x = jnp.array([1.0, 2.0])
    self.model.init({"params": self.rng, "dropout": self.rng}, x)
 
+
 class DeepSeekRoutingTest(unittest.TestCase):
 
  def setUp(self):
@@ -213,10 +217,9 @@ def setUp(self):
        num_experts_per_tok=4,
        sparse_matmul=True,
    )
-    self.rng = jax.random.PRNGKey(42)
+    self.rngs = nnx.Rngs(params=0)
    devices_array = maxtext_utils.create_device_mesh(self.cfg)
    self.model = moe.RoutedMoE(
-        name="MoeBlock",
        config=self.cfg,
        num_experts=self.cfg.num_experts,
        num_experts_per_tok=self.cfg.num_experts_per_tok,
@@ -224,6 +227,7 @@ def setUp(self):
        kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"),
        kernel_axes=("embed", "mlp"),
        dtype=self.cfg.dtype,
+        rngs=self.rngs,
    )
 
  def test_deepseek_routing(self):
@@ -262,30 +266,56 @@ def test_deepseek_routing(self):
    )
 
 
-class MoeLoopBlock(nn.Module):
+class MoeLoopBlock(nnx.Module):
  """Reference implementation from https://github.com/mistralai/mistral-inference.
 
  This is not included anymore in our repo, due to a limitation of for-loop implementation in sharding.
  """
 
-  config: Config
-  num_experts: int
-  num_experts_per_tok: int
-  kernel_init: NdInitializer
-  kernel_axes: Tuple[str, ...]
-  weight_dtype: DType = jnp.float32
-  dtype: DType = jnp.bfloat16
-
-  @nn.compact
-  def __call__(self, inputs, deterministic: bool = False):
-    gate_logits = moe.GateLogit(
-        self.num_experts,
-        self.config.model_name,
+  def __init__(
+      self,
+      config: Config,
+      inputs_shape: tuple[int, ...],
+      num_experts: int,
+      num_experts_per_tok: int,
+      kernel_init: NdInitializer,
+      kernel_axes: Tuple[str, ...],
+      *,
+      rngs: nnx.Rngs,
+      weight_dtype: DType = jnp.float32,
+      dtype: DType = jnp.bfloat16,
+  ):
+    self.config = config
+    self.inputs_shape = inputs_shape
+    self.num_experts = num_experts
+    self.num_experts_per_tok = num_experts_per_tok
+    self.kernel_init = kernel_init
+    self.kernel_axes = kernel_axes
+    self.weight_dtype = weight_dtype
+    self.dtype = dtype
+    self.gate = moe.GateLogit(
+        in_features_shape=self.inputs_shape[-1],
+        out_features_shape=self.num_experts,
+        model_name=self.config.model_name,
        dtype=self.dtype,
        kernel_init=self.kernel_init,
        kernel_axes=self.kernel_axes,
-        name="gate",
-    )(inputs)[0]
+        rngs=rngs,
+    )
+    for k in range(self.num_experts):
+      expert_module = linears.MlpBlock(
+          config=self.config,
+          in_features=self.inputs_shape[-1],
+          intermediate_dim=self.config.mlp_dim,
+          activations=["silu", "linear"],
+          intermediate_dropout_rate=self.config.dropout_rate,
+          dtype=dtype,
+          weight_dtype=weight_dtype,
+          rngs=rngs,
+      )
+      setattr(self, f"mlp_{k}", expert_module)
+
+  def __call__(self, inputs, deterministic: bool = False):
+    gate_logits = self.gate(inputs)[0]
    weights, selected_experts = jax.lax.top_k(gate_logits, self.num_experts_per_tok)
    weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1).astype(self.weight_dtype)
    mlp_lnx = jnp.zeros_like(inputs)
@@ -293,17 +323,8 @@ def __call__(self, inputs, deterministic: bool = False):
 
    for k in range(self.num_experts):
      weights_exp = jnp.sum(jnp.multiply(selected_experts == k, weights), axis=-1)
-      mlp_lnx_exp = mlp_block(
-          config=self.config,
-          in_features=inputs.shape[-1],
-          intermediate_dim=self.config.mlp_dim,
-          activations=["silu", "linear"],
-          intermediate_dropout_rate=self.config.dropout_rate,
-          dtype=self.dtype,
-          weight_dtype=self.weight_dtype,
-          name=f"mlp_{k}",
-      )(inputs, deterministic=deterministic)
-
+      mlp_lnx_exp = getattr(self, f"mlp_{k}")(inputs, deterministic=deterministic)
      mlp_lnx_exp = nn.with_logical_constraint(mlp_lnx_exp, ("activation_batch", "activation_length", "activation_embed"))
      mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp
      mlp_lnx += mlp_lnx_exp
@@ -311,18 +332,46 @@ def __call__(self, inputs, deterministic: bool = False):
    return mlp_lnx
 
 
+def get_moe_loop(
+    config: Config,
+    inputs_shape: tuple[int, ...],
+    num_experts: int,
+    num_experts_per_tok: int,
+    kernel_init: NdInitializer,
+    kernel_axes: Tuple[str, ...],
+    weight_dtype: DType = jnp.float32,
+    dtype: DType = jnp.bfloat16,
+):
+  """Creates a MoeLoopBlock Linen module."""
+  module = nnx_wrappers.to_linen(
+      MoeLoopBlock,
+      config=config,
+      inputs_shape=inputs_shape,
+      num_experts=num_experts,
+      num_experts_per_tok=num_experts_per_tok,
+      kernel_init=kernel_init,
+      kernel_axes=kernel_axes,
+      weight_dtype=weight_dtype,
+      dtype=dtype,
+      metadata_fn=variable_to_logically_partitioned,
+  )
+  return module
+
+
 class RoutedMoeTest(unittest.TestCase):
  """Routed Mixture of Experts test."""
 
  def get_expected_output(self, rng, hidden_states, cfg):
    """Retrieve expected output from Routed Mixture of Experts."""
-    model = MoeLoopBlock(
+    model = get_moe_loop(
        config=cfg,
+        inputs_shape=hidden_states.shape,
        num_experts=cfg.num_experts,
        num_experts_per_tok=cfg.num_experts_per_tok,
        kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"),
        kernel_axes=("embed", "mlp"),
        dtype=cfg.dtype,
+        weight_dtype=cfg.weight_dtype,
    )
    variables = model.init(
        rng, jax.random.normal(rng, (int(cfg.per_device_batch_size), cfg.max_target_length, cfg.base_emb_dim))
    )
@@ -333,7 +382,7 @@ def get_expected_output(self, rng, hidden_states, cfg):
 
  def get_moe_output(self, variables, hidden_states, cfg, mesh):
    """retrieve expected output from MoE"""
-    model = moe.RoutedMoE(
+    model = moe_linen.get_routed_moe(
        name="MoeBlock",
        config=cfg,
        num_experts=cfg.num_experts,