Mamba + Tensor Parallel Support #1184

Merged
merged 9 commits on Mar 15, 2024
8 changes: 4 additions & 4 deletions megatron/model/gpt2_model.py
@@ -37,7 +37,7 @@
ParallelLinear,
)
from megatron.model.gmlp import GMLPBlock
from megatron.model.mamba import MambaResidualLayerPipe
from megatron.model.mamba import ParallelMambaResidualLayerPipe
from megatron.model.word_embeddings import EmbeddingPipe, SoftEmbedding

# Pipeline parallelism
@@ -138,7 +138,7 @@ def __init__(
checkpointable_layers=[
"GMLPBlock",
"ParallelTransformerLayerPipe",
"MambaResidualLayerPipe",
"ParallelMambaResidualLayerPipe",
],
)

@@ -174,7 +174,7 @@ def insert_layers(
checkpointable_layers=[
"GMLPBlock",
"ParallelTransformerLayerPipe",
"MambaResidualLayerPipe",
"ParallelMambaResidualLayerPipe",
],
)

@@ -254,7 +254,7 @@ def init_specs(self):
elif layer_type in ["mamba"]:
self.specs.append(
LayerSpec(
MambaResidualLayerPipe,
ParallelMambaResidualLayerPipe,
neox_args=self.neox_args,
init_method=self.init_method,
output_layer_init_method=self.output_layer_init_method,
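A note on the checkpointable_layers strings in the hunks above: the sketch below is not DeepSpeed's actual source (the stand-in layer class is hypothetical), but as I understand the pipeline wrapper, activation checkpointing is gated on the layer's class name, which is why the configured string has to follow the ParallelMambaResidualLayerPipe rename.

```python
# Sketch of the name-based check (not DeepSpeed's code; the stand-in layer
# below is hypothetical): a stale "MambaResidualLayerPipe" entry would
# silently disable activation checkpointing for the renamed class.
import torch.nn as nn

CHECKPOINTABLE_LAYERS = [
    "GMLPBlock",
    "ParallelTransformerLayerPipe",
    "ParallelMambaResidualLayerPipe",
]

def is_checkpointable(layer: nn.Module) -> bool:
    # the pipeline wrapper compares the layer's class name against the list
    return type(layer).__name__ in CHECKPOINTABLE_LAYERS

class ParallelMambaResidualLayerPipe(nn.Module):  # stand-in for the real layer
    def forward(self, x):
        return x

print(is_checkpointable(ParallelMambaResidualLayerPipe()))  # True
```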
5 changes: 4 additions & 1 deletion megatron/model/mamba/__init__.py
@@ -1 +1,4 @@
from .mamba import MambaResidualLayer, MambaResidualLayerPipe
from .mamba import (
ParallelMambaResidualLayer,
ParallelMambaResidualLayerPipe,
)
106 changes: 64 additions & 42 deletions megatron/model/mamba/mamba.py
@@ -19,10 +19,10 @@
pass

from megatron.model.norms import get_norm
from megatron import mpu


# Mamba layer, without parallelism.
class MambaBlock(nn.Module):
# Mamba sublayer, with tensor parallelism
class ParallelMambaBlock(nn.Module):
def __init__(
self,
neox_args,
@@ -58,50 +58,64 @@ def __init__(
self.dt_min, self.dt_max, self.dt_init_floor = 0.001, 0.1, 1e-4
assert self.dt_init in ["constant", "random"]

# TP-specific setup
world_size = mpu.get_model_parallel_world_size()
self.d_inner_per_rank = mpu.divide(self.d_inner, world_size)

if neox_args.mamba_inner_func_fusion and world_size > 1:
# as with the gpt-j residual, we must manually reduce the output of the final proj
# across TP ranks, since this is not done by the fused mamba_inner_fn.
self.reduce = mpu.mappings.reduce_from_model_parallel_region

# up-projection.
self.in_proj = nn.Linear(
self.d_model,
self.d_inner * 2,
self.in_proj = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=self.d_model,
output_size=self.d_inner * 2,
gather_output=False,
init_method=init_method,
skip_bias_add=not neox_args.mamba_use_bias_in_linears,
bias=neox_args.mamba_use_bias_in_linears,
**factory_kwargs,
)
init_method(self.in_proj.weight)

# convolution.
# convolution (parallelized across d_inner)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
in_channels=self.d_inner_per_rank,
out_channels=self.d_inner_per_rank,
bias=neox_args.mamba_use_bias_in_conv,
kernel_size=self.d_conv,
groups=self.d_inner,
groups=self.d_inner_per_rank,
padding=self.d_conv - 1,
**factory_kwargs,
)
# The conv bias sometimes erroneously ends up in 32-bit when other parameters are held in fp32.
# The cause is unclear, so cast explicitly to the configured precision.
self.conv1d.to(self.precision)

self.act_fn = F.silu # we do not allow for
self.act_fn = F.silu # we do not allow for other activation fns

# x_proj corresponds to s_B(x), s_C(x), s_Delta(x)
# in https://arxiv.org/pdf/2312.00752.pdf Algorithm 2
# (computes data-dependent B, C, Delta/dt)
self.x_proj = nn.Linear(
self.d_inner,
self.dt_rank + self.d_state * 2,
self.x_proj = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=self.d_inner,
output_size=self.dt_rank + self.d_state * 2,
input_is_parallel=True,
init_method=init_method,
skip_bias_add=not neox_args.mamba_use_bias_in_linears,
parallel_output=True,
bias=neox_args.mamba_use_bias_in_linears,
**factory_kwargs,
)
init_method(self.x_proj.weight)

# up-project dt / Delta from dt_rank to d_inner
# dt_proj 's bias is a special case and I believe we should keep it turned on -- Alg. 2 in the Mamba paper (https://arxiv.org/abs/2312.00752)
# dt_proj's bias is a special case and should always be kept turned on -- Alg. 2 in the Mamba paper (https://arxiv.org/abs/2312.00752)
# defines Delta as Delta = Tau_{Delta}(Parameter + s_{Delta}(x)) where s_{Delta}(x) = Broadcast_{D}(Linear_{1}(x))
# or as they further explain in section 3.6 can be also s_{Delta}(x) = Linear_{D}(Linear_{R}(x)) where Linear_R
# is the delta portion of x_proj and Linear_D is the dt_proj weight. Then, the Parameter term from Alg. 2 can
# be viewed as the bias term in dt_proj, with a special initialization from https://arxiv.org/abs/2206.12037
self.dt_proj = nn.Linear(
self.dt_rank, self.d_inner, bias=True, **factory_kwargs
self.dt_rank, self.d_inner_per_rank, bias=True, **factory_kwargs
)

# special init for dt_proj
@@ -115,7 +129,7 @@ def __init__(

# more dt_proj init stuff. copied from https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/modules/mamba_simple.py#L91-L101
dt = torch.exp(
torch.rand(self.d_inner, **factory_kwargs)
torch.rand(self.d_inner_per_rank, **factory_kwargs)
* (math.log(self.dt_max) - math.log(self.dt_min))
+ math.log(self.dt_min)
).clamp(min=self.dt_init_floor)
@@ -133,7 +147,7 @@ def __init__(
device=torch.cuda.current_device(),
),
"n -> d n",
d=self.d_inner,
d=self.d_inner_per_rank,
).contiguous()
A_log = torch.log(A).to(
torch.float32
@@ -150,7 +164,9 @@ def __init__(
# D parameter
self.D = nn.Parameter(
torch.ones(
self.d_inner, device=torch.cuda.current_device(), dtype=torch.float32
self.d_inner_per_rank,
device=torch.cuda.current_device(),
dtype=torch.float32,
)
).to(
torch.float32
@@ -163,14 +179,20 @@ def __init__(
if self.neox_args.mamba_selective_fp32_params:
self.D._deepspeed_no_cast = True

# out down-projection
self.out_proj = nn.Linear(
self.d_inner,
self.d_model,
# out down-projection.
# use "single_residual_scaled_normal"
# for output_layer_init_method
# to perform gpt-2 style scaled init as done in the Mamba paper.
self.out_proj = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=self.d_inner,
output_size=self.d_model,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=not neox_args.mamba_use_bias_in_linears,
bias=neox_args.mamba_use_bias_in_linears,
**factory_kwargs,
parallel_output=False,
)
output_layer_init_method(self.out_proj.weight)

def selective_scan(
self,
@@ -224,14 +246,8 @@ def forward(self, hidden_states):
seqlen, batch, dim = hidden_states.shape

# first up: perform in_proj
xz = einops.rearrange(
self.in_proj.weight @ einops.rearrange(hidden_states, "l b d -> d (b l)"),
"d (b l) -> b d l",
l=seqlen,
)

if self.in_proj.bias is not None:
xz = xz + einops.rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
xz, _ = self.in_proj(hidden_states)
xz = einops.rearrange(xz, "l b d -> b d l")

A = -torch.exp(self.A_log.float()) # (d_inner, d_state)

@@ -262,6 +278,12 @@ def forward(self, hidden_states):
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
if getattr(self, "reduce", None):
# manually reduce after mamba_inner_fn to collect outputs
# from the different TP ranks. in the non-fused path this
# reduction happens inside self.out_proj(y) below, so it is
# only needed here.
out = self.reduce(out)

out = einops.rearrange(out, "b l h -> l b h")

@@ -292,7 +314,7 @@ def forward(self, hidden_states):
# ==============

# project: perform s_B, s_C, s_Delta projections
x_dbl = self.x_proj(einops.rearrange(x, "b d l -> (b l) d"))
x_dbl, _ = self.x_proj(einops.rearrange(x, "b d l -> (b l) d"))
# split into component dt, B, C
dt, B, C = torch.split(
x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
@@ -324,14 +346,14 @@ def forward(self, hidden_states):
# ===============
y = einops.rearrange(y, "b d l -> b l d")

out = self.out_proj(y)
out, _ = self.out_proj(y)

out = einops.rearrange(out, "b l h -> l b h")

return out


class MambaResidualLayer(nn.Module):
class ParallelMambaResidualLayer(nn.Module):
"""
Pre-norm Mamba block (tensor-parallel) with residual connection.
"""
@@ -352,7 +374,7 @@ def __init__(

self.norm = norm(neox_args.hidden_size, eps=eps)

self.mixer = MambaBlock(
self.mixer = ParallelMambaBlock(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
@@ -369,7 +391,7 @@ def forward(self, x, attention_mask=None, layer_past=None):
return hidden_states + residual


class MambaResidualLayerPipe(MambaResidualLayer):
class ParallelMambaResidualLayerPipe(ParallelMambaResidualLayer):
"""Extends MambaResidualLayer to forward attention_mask through the pipeline. DeepSpeed requires this."""

def forward(self, args):
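For readers skimming the diff above, here is a conceptual single-process sketch of the sharding scheme the new ParallelMambaBlock uses (not the mpu implementation; the shapes, the 4-rank loop, and the relu stand-in for the conv/SSM are illustrative only): in_proj is column-parallel, so each rank owns a d_inner_per_rank slice; the depthwise conv1d, dt_proj, A_log, and D all operate on that local slice; and out_proj is row-parallel, so the partial outputs must be summed across ranks, which is the all-reduce RowParallelLinear does internally and the fused mamba_inner_fn path has to do manually via reduce_from_model_parallel_region.

```python
# Conceptual sketch of column-parallel in -> per-shard op -> row-parallel out.
# Shapes and rank count are illustrative; relu stands in for the conv/SSM,
# which (like the depthwise conv1d) acts per-channel along d_inner and so
# shards cleanly across ranks.
import torch

torch.manual_seed(0)
d_model, d_inner, world_size = 16, 32, 4
shard = d_inner // world_size  # per-rank slice, as in mpu.divide(d_inner, world_size)

x = torch.randn(3, d_model)            # (tokens, d_model)
w_in = torch.randn(d_inner, d_model)   # full in_proj weight
w_out = torch.randn(d_model, d_inner)  # full out_proj weight

# reference: unsharded forward
ref = torch.relu(x @ w_in.t()) @ w_out.t()

# tensor-parallel forward, simulating each rank's local computation
partials = []
for rank in range(world_size):
    rows = slice(rank * shard, (rank + 1) * shard)
    w_in_rank = w_in[rows, :]    # column-parallel: shard the output dim
    w_out_rank = w_out[:, rows]  # row-parallel: shard the input dim
    h_rank = torch.relu(x @ w_in_rank.t())    # local (tokens, shard) activations
    partials.append(h_rank @ w_out_rank.t())  # partial (tokens, d_model) output

out = torch.stack(partials).sum(dim=0)  # the all-reduce across TP ranks
print(torch.allclose(ref, out, atol=1e-4))  # True
```

In the fused-kernel path the out_proj matmul happens inside mamba_inner_fn, which knows nothing about tensor parallelism, hence the explicit self.reduce(out) added in forward().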
3 changes: 0 additions & 3 deletions megatron/neox_arguments/arguments.py
@@ -1061,9 +1061,6 @@ def calculate_derived(self):
not self.partition_activations
), "GMLP Blocks are not compatible with partition activations"
if "mamba" in self.attention_config:
assert (
not self.is_pipe_parallel and self.model_parallel_size == 1
), "Mamba not currently compatible with parallelism"
if isinstance(self.zero_stage, int):
assert self.zero_stage <= 2, "Zero stage 3 not compatible with Mamba"
assert (
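The effect of the removed assertion, as a standalone sketch (the helper and argument names below are hypothetical, not the real calculate_derived): pipeline and tensor parallelism are no longer rejected when "mamba" is in the attention config, while the ZeRO stage <= 2 requirement still applies.

```python
# Illustrative only -- a simplified version of the remaining compatibility
# guard after this change (names are hypothetical, not the NeoXArgs API).
def check_mamba_compat(attention_config, zero_stage, model_parallel_size, is_pipe_parallel):
    if "mamba" in attention_config:
        # the blanket "not is_pipe_parallel and model_parallel_size == 1"
        # assertion is gone; only the ZeRO-3 restriction remains
        if isinstance(zero_stage, int):
            assert zero_stage <= 2, "Zero stage 3 not compatible with Mamba"

# previously rejected, now allowed:
check_mamba_compat(["mamba"], zero_stage=1, model_parallel_size=2, is_pipe_parallel=True)
```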
35 changes: 11 additions & 24 deletions tools/ckpts/convert_neox_to_mamba_ssm.py
@@ -27,24 +27,26 @@
"""
ARCH = {
"COLUMN_PARALLEL_LINEAR_KEYS": {
# these require concat across dim=0
"mixer.in_proj.weight": "mixer.in_proj.weight",
# "mixer.in_proj.bias": "mixer.in_proj.bias",
"mixer.A_log": "mixer.A_log",
"mixer.D": "mixer.D",
"mixer.conv1d.weight": "mixer.conv1d.weight",
"mixer.conv1d.bias": "mixer.conv1d.bias",
"mixer.dt_proj.weight": "mixer.dt_proj.weight",
"mixer.dt_proj.bias": "mixer.dt_proj.bias",
},
"ROW_PARALLEL_LINEAR_KEYS": {
# these require concat across dim=1
"mixer.out_proj.weight": "mixer.out_proj.weight",
"mixer.x_proj.weight": "mixer.x_proj.weight",
},
"ROW_PARALLEL_BIAS_KEYS": {
# these require summing across ranks
# "mixer.x_proj.bias": "mixer.x_proj.bias",
# "mixer.out_proj.bias": "mixer.out_proj.bias",
},
"NO_SHARD_KEYS": {
"mixer.A_log": "mixer.A_log",
"mixer.D": "mixer.D",
"mixer.x_proj.weight": "mixer.x_proj.weight",
"mixer.dt_proj.weight": "mixer.dt_proj.weight",
"mixer.dt_proj.bias": "mixer.dt_proj.bias",
"mixer.conv1d.weight": "mixer.conv1d.weight",
"mixer.conv1d.bias": "mixer.conv1d.bias",
},
"NORM_KEYS": {
"norm.scale": "norm.weight",
# "norm.bias": "norm.bias",
@@ -226,15 +228,6 @@ def convert(
)
)

# Average params which aren't sharded across ranks.
# they should be the same across ranks, so should be fine
for key, hf_key in ARCH["NO_SHARD_KEYS"].items():
state_dict[hf_key] = sum(
get_state(
loaded_tp_ranks, key, layer_idx=layer_i + 2, sequential=sequential
)
) / len(loaded_tp_ranks)

layer.load_state_dict(state_dict)

if not sequential:
@@ -320,12 +313,6 @@ def main(input_args=None, overwrite_values=None):
action="store_true",
help="Whether to skip saving the tokenizer alongside a model.",
)
parser.add_argument(
"--architecture",
type=str,
default="neox",
help="What HF model class type to export into.",
)
args = parser.parse_args(input_args)

# validate arguments
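For orientation, a sketch of what the re-grouped ARCH keys imply for the export path (the two-rank state dicts and shapes below are made up, and this is not the converter's exact code): column-parallel weights are re-assembled by concatenating along dim 0, row-parallel weights along dim 1; the row-parallel bias keys, which would require summing across ranks, stay commented out here.

```python
# Illustrative merge of tensor-parallel shards into one mamba_ssm-style
# state dict (hypothetical shapes and rank count; not the converter's code).
import torch

tp_shards = [  # state_dicts from two hypothetical TP ranks
    {"mixer.in_proj.weight": torch.randn(32, 16),
     "mixer.out_proj.weight": torch.randn(16, 32)},
    {"mixer.in_proj.weight": torch.randn(32, 16),
     "mixer.out_proj.weight": torch.randn(16, 32)},
]

merged = {}
# column-parallel: the output dim was sharded -> concat on dim=0
merged["mixer.in_proj.weight"] = torch.cat(
    [sd["mixer.in_proj.weight"] for sd in tp_shards], dim=0
)  # -> (64, 16)
# row-parallel: the input dim was sharded -> concat on dim=1
merged["mixer.out_proj.weight"] = torch.cat(
    [sd["mixer.out_proj.weight"] for sd in tp_shards], dim=1
)  # -> (16, 64)
print({k: tuple(v.shape) for k, v in merged.items()})
```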