Mamba + Tensor Parallel Support (#1184)
* TP works!

* merge TP mamba changes with most current MambaLayer

* cleanup TP, confirmed working still

* make shapes with TP>1 work with conversion

* tested and PP works, so no need for assert blocking it in arguments

* update comment

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Quentin Anthony <qganthony@yahoo.com>
3 people authored Mar 15, 2024
1 parent 03186de commit 277141e
Showing 6 changed files with 84 additions and 75 deletions.
2 changes: 1 addition & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = fdac107
Default = f70c54d

current git hash of repository

8 changes: 4 additions & 4 deletions megatron/model/gpt2_model.py
@@ -37,7 +37,7 @@
ParallelLinear,
)
from megatron.model.gmlp import GMLPBlock
from megatron.model.mamba import MambaResidualLayerPipe
from megatron.model.mamba import ParallelMambaResidualLayerPipe
from megatron.model.word_embeddings import EmbeddingPipe, SoftEmbedding

# Pipeline parallelism
@@ -138,7 +138,7 @@ def __init__(
checkpointable_layers=[
"GMLPBlock",
"ParallelTransformerLayerPipe",
"MambaResidualLayerPipe",
"ParallelMambaResidualLayerPipe",
],
)

@@ -174,7 +174,7 @@ def insert_layers(
checkpointable_layers=[
"GMLPBlock",
"ParallelTransformerLayerPipe",
"MambaResidualLayerPipe",
"ParallelMambaResidualLayerPipe",
],
)

@@ -254,7 +254,7 @@ def init_specs(self):
elif layer_type in ["mamba"]:
self.specs.append(
LayerSpec(
MambaResidualLayerPipe,
ParallelMambaResidualLayerPipe,
neox_args=self.neox_args,
init_method=self.init_method,
output_layer_init_method=self.output_layer_init_method,
5 changes: 4 additions & 1 deletion megatron/model/mamba/__init__.py
@@ -1 +1,4 @@
from .mamba import MambaResidualLayer, MambaResidualLayerPipe
from .mamba import (
ParallelMambaResidualLayer,
ParallelMambaResidualLayerPipe,
)
106 changes: 64 additions & 42 deletions megatron/model/mamba/mamba.py
@@ -19,10 +19,10 @@
pass

from megatron.model.norms import get_norm
from megatron import mpu


# Mamba layer, without parallelism.
class MambaBlock(nn.Module):
# Mamba sublayer, with tensor parallelism
class ParallelMambaBlock(nn.Module):
def __init__(
self,
neox_args,
@@ -58,50 +58,64 @@ def __init__(
self.dt_min, self.dt_max, self.dt_init_floor = 0.001, 0.1, 1e-4
assert self.dt_init in ["constant", "random"]

# TP-specific setup
world_size = mpu.get_model_parallel_world_size()
self.d_inner_per_rank = mpu.divide(self.d_inner, world_size)

if neox_args.mamba_inner_func_fusion and world_size > 1:
# as with gpt-j residual, we must manually reduce output from final proj
# across TP ranks, since it is not done by fused mamba_inner_fn .
self.reduce = mpu.mappings.reduce_from_model_parallel_region

# up-projection.
self.in_proj = nn.Linear(
self.d_model,
self.d_inner * 2,
self.in_proj = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=self.d_model,
output_size=self.d_inner * 2,
gather_output=False,
init_method=init_method,
skip_bias_add=not neox_args.mamba_use_bias_in_linears,
bias=neox_args.mamba_use_bias_in_linears,
**factory_kwargs,
)
init_method(self.in_proj.weight)

# convolution.
# convolution (parallelized across d_inner)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
in_channels=self.d_inner_per_rank,
out_channels=self.d_inner_per_rank,
bias=neox_args.mamba_use_bias_in_conv,
kernel_size=self.d_conv,
groups=self.d_inner,
groups=self.d_inner_per_rank,
padding=self.d_conv - 1,
**factory_kwargs,
)
# Conv bias sometimes in 32-bit erroneously, when holding other parameters in fp32.
# Uncertain why
self.conv1d.to(self.precision)

self.act_fn = F.silu # we do not allow for
self.act_fn = F.silu # we do not allow for other activation fns

# x_proj corresponds to s_B(x), s_C(x), s_Delta(x)
# in https://arxiv.org/pdf/2312.00752.pdf Algorithm 2
# (computes data-dependent B, C, Delta/dt)
self.x_proj = nn.Linear(
self.d_inner,
self.dt_rank + self.d_state * 2,
self.x_proj = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=self.d_inner,
output_size=self.dt_rank + self.d_state * 2,
input_is_parallel=True,
init_method=init_method,
skip_bias_add=not neox_args.mamba_use_bias_in_linears,
parallel_output=True,
bias=neox_args.mamba_use_bias_in_linears,
**factory_kwargs,
)
init_method(self.x_proj.weight)

# up-project dt / Delta from dt_rank to d_inner
# dt_proj 's bias is a special case and I believe we should keep it turned on -- Alg. 2 in the Mamba paper (https://arxiv.org/abs/2312.00752)
# dt_proj 's bias is a special case and should be kept always turned on -- Alg. 2 in the Mamba paper (https://arxiv.org/abs/2312.00752)
# defines Delta as Delta = Tau_{Delta}(Parameter + s_{Delta}(x)) where s_{Delta}(x) = Broadcast_{D}(Linear_{1}(x))
# or as they further explain in section 3.6 can be also s_{Delta}(x) = Linear_{D}(Linear_{R}(x)) where Linear_R
# is the delta portion of x_proj and Linear_D is the dt_proj weight. Then, the Parameter term from Alg. 2 can
# be viewed as the bias term in dt_proj, with a special initialization from https://arxiv.org/abs/2206.12037
self.dt_proj = nn.Linear(
self.dt_rank, self.d_inner, bias=True, **factory_kwargs
self.dt_rank, self.d_inner_per_rank, bias=True, **factory_kwargs
)

# special init for dt_proj
@@ -115,7 +129,7 @@ def __init__(

# more dt_proj init stuff. copied from https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/modules/mamba_simple.py#L91-L101
dt = torch.exp(
torch.rand(self.d_inner, **factory_kwargs)
torch.rand(self.d_inner_per_rank, **factory_kwargs)
* (math.log(self.dt_max) - math.log(self.dt_min))
+ math.log(self.dt_min)
).clamp(min=self.dt_init_floor)
@@ -133,7 +147,7 @@ def __init__(
device=torch.cuda.current_device(),
),
"n -> d n",
d=self.d_inner,
d=self.d_inner_per_rank,
).contiguous()
A_log = torch.log(A).to(
torch.float32
@@ -150,7 +164,9 @@ def __init__(
# D parameter
self.D = nn.Parameter(
torch.ones(
self.d_inner, device=torch.cuda.current_device(), dtype=torch.float32
self.d_inner_per_rank,
device=torch.cuda.current_device(),
dtype=torch.float32,
)
).to(
torch.float32
@@ -163,14 +179,20 @@ def __init__(
if self.neox_args.mamba_selective_fp32_params:
self.D._deepspeed_no_cast = True

# out down-projection
self.out_proj = nn.Linear(
self.d_inner,
self.d_model,
# out down-projection.
# use "single_residual_scaled_normal"
# for output_layer_init_method
# to perform gpt-2 style scaled init as done in Mamba paper.
self.out_proj = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=self.d_inner,
output_size=self.d_model,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=not neox_args.mamba_use_bias_in_linears,
bias=neox_args.mamba_use_bias_in_linears,
**factory_kwargs,
parallel_output=False,
)
output_layer_init_method(self.out_proj.weight)

def selective_scan(
self,
@@ -224,14 +246,8 @@ def forward(self, hidden_states):
seqlen, batch, dim = hidden_states.shape

# first up: perform in_proj
xz = einops.rearrange(
self.in_proj.weight @ einops.rearrange(hidden_states, "l b d -> d (b l)"),
"d (b l) -> b d l",
l=seqlen,
)

if self.in_proj.bias is not None:
xz = xz + einops.rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
xz, _ = self.in_proj(hidden_states)
xz = einops.rearrange(xz, "l b d -> b d l")

A = -torch.exp(self.A_log.float()) # (d_inner, d_state)

@@ -262,6 +278,12 @@ def forward(self, hidden_states):
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
if getattr(self, "reduce", None):
# manually reduce after mamba_inner_fn
# to collect outputs from different TP ranks.
# handled by running self.out_proj(y) below
# so only needed here.
out = self.reduce(out)

out = einops.rearrange(out, "b l h -> l b h")

@@ -292,7 +314,7 @@ def forward(self, hidden_states):
# ==============

# project: perform s_B, s_C, s_Delta projections
x_dbl = self.x_proj(einops.rearrange(x, "b d l -> (b l) d"))
x_dbl, _ = self.x_proj(einops.rearrange(x, "b d l -> (b l) d"))
# split into component dt, B, C
dt, B, C = torch.split(
x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
@@ -324,14 +346,14 @@ def forward(self, hidden_states):
# ===============
y = einops.rearrange(y, "b d l -> b l d")

out = self.out_proj(y)
out, _ = self.out_proj(y)

out = einops.rearrange(out, "b l h -> l b h")

return out


class MambaResidualLayer(nn.Module):
class ParallelMambaResidualLayer(nn.Module):
"""
Pre-norm Mamba Block with residual connection. No parallelism yet supported.
"""
@@ -352,7 +374,7 @@ def __init__(

self.norm = norm(neox_args.hidden_size, eps=eps)

self.mixer = MambaBlock(
self.mixer = ParallelMambaBlock(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
@@ -369,7 +391,7 @@ def forward(self, x, attention_mask=None, layer_past=None):
return hidden_states + residual


class MambaResidualLayerPipe(MambaResidualLayer):
class ParallelMambaResidualLayerPipe(ParallelMambaResidualLayer):
"""Extends MambaResidualLayer to forward attention_mask through the pipeline. DeepSpeed requires this."""

def forward(self, args):
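The heart of this change is splitting the Mamba block's inner dimension across tensor-parallel ranks: in_proj becomes a ColumnParallelLinear whose 2 * d_inner output is partitioned, the depthwise conv1d plus the per-channel SSM parameters (dt_proj, A_log, D) are sized to the local d_inner_per_rank slice, and x_proj / out_proj become RowParallelLinear layers that consume that slice. The sketch below is illustrative only — mamba_tp_shapes is a hypothetical helper, not part of this repository — and simply reproduces the per-rank parameter shapes implied by the diff, assuming Mamba's usual dt_rank = ceil(d_model / 16) heuristic and the paper's default d_state / d_conv / expand values.

```python
# Hypothetical shape sketch (not the NeoX implementation) of how the Mamba
# block's parameters are partitioned when the tensor-parallel world size > 1.
import math

def mamba_tp_shapes(d_model, d_state=16, d_conv=4, expand=2, world_size=2):
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)        # assumed default dt_rank heuristic
    assert d_inner % world_size == 0         # mirrors mpu.divide(d_inner, world_size)
    d_inner_per_rank = d_inner // world_size
    return {
        # column-parallel: the 2 * d_inner output dim is split across ranks
        "in_proj.weight": (2 * d_inner // world_size, d_model),
        # depthwise conv over the local slice of channels
        "conv1d.weight": (d_inner_per_rank, 1, d_conv),
        # row-parallel: the d_inner input dim is split across ranks
        "x_proj.weight": (dt_rank + 2 * d_state, d_inner_per_rank),
        # dt_proj maps dt_rank -> the local d_inner slice
        "dt_proj.weight": (d_inner_per_rank, dt_rank),
        # per-channel SSM parameters follow the local slice
        "A_log": (d_inner_per_rank, d_state),
        "D": (d_inner_per_rank,),
        # row-parallel down-projection back to d_model (output is all-reduced)
        "out_proj.weight": (d_model, d_inner_per_rank),
    }

if __name__ == "__main__":
    for name, shape in mamba_tp_shapes(d_model=1024, world_size=2).items():
        print(f"{name}: {shape}")
```

For example, with d_model=1024 and a tensor-parallel world size of 2, each rank holds an in_proj.weight of shape (2048, 1024) rather than (4096, 1024), which is why the conversion script below concatenates it back along dim=0.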
3 changes: 0 additions & 3 deletions megatron/neox_arguments/arguments.py
@@ -1061,9 +1061,6 @@ def calculate_derived(self):
not self.partition_activations
), "GMLP Blocks are not compatible with partition activations"
if "mamba" in self.attention_config:
assert (
not self.is_pipe_parallel and self.model_parallel_size == 1
), "Mamba not currently compatible with parallelism"
if isinstance(self.zero_stage, int):
assert self.zero_stage <= 2, "Zero stage 3 not compatible with Mamba"
assert (
35 changes: 11 additions & 24 deletions tools/ckpts/convert_neox_to_mamba_ssm.py
@@ -27,24 +27,26 @@
"""
ARCH = {
"COLUMN_PARALLEL_LINEAR_KEYS": {
# these require concat across dim=0
"mixer.in_proj.weight": "mixer.in_proj.weight",
# "mixer.in_proj.bias": "mixer.in_proj.bias",
"mixer.A_log": "mixer.A_log",
"mixer.D": "mixer.D",
"mixer.conv1d.weight": "mixer.conv1d.weight",
"mixer.conv1d.bias": "mixer.conv1d.bias",
"mixer.dt_proj.weight": "mixer.dt_proj.weight",
"mixer.dt_proj.bias": "mixer.dt_proj.bias",
},
"ROW_PARALLEL_LINEAR_KEYS": {
# these require concat across dim=1
"mixer.out_proj.weight": "mixer.out_proj.weight",
"mixer.x_proj.weight": "mixer.x_proj.weight",
},
"ROW_PARALLEL_BIAS_KEYS": {
# these require summing across ranks
# "mixer.x_proj.bias": "mixer.x_proj.bias",
# "mixer.out_proj.bias": "mixer.out_proj.bias",
},
"NO_SHARD_KEYS": {
"mixer.A_log": "mixer.A_log",
"mixer.D": "mixer.D",
"mixer.x_proj.weight": "mixer.x_proj.weight",
"mixer.dt_proj.weight": "mixer.dt_proj.weight",
"mixer.dt_proj.bias": "mixer.dt_proj.bias",
"mixer.conv1d.weight": "mixer.conv1d.weight",
"mixer.conv1d.bias": "mixer.conv1d.bias",
},
"NORM_KEYS": {
"norm.scale": "norm.weight",
# "norm.bias": "norm.bias",
@@ -226,15 +228,6 @@ def convert(
)
)

# Average params which aren't sharded across ranks.
# they should be the same across ranks, so should be fine
for key, hf_key in ARCH["NO_SHARD_KEYS"].items():
state_dict[hf_key] = sum(
get_state(
loaded_tp_ranks, key, layer_idx=layer_i + 2, sequential=sequential
)
) / len(loaded_tp_ranks)

layer.load_state_dict(state_dict)

if not sequential:
@@ -320,12 +313,6 @@ def main(input_args=None, overwrite_values=None):
action="store_true",
help="Whether to skip saving the tokenizer alongside a model.",
)
parser.add_argument(
"--architecture",
type=str,
default="neox",
help="What HF model class type to export into.",
)
args = parser.parse_args(input_args)

# validate arguments
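The reworked ARCH table above encodes one merge rule per parameter when collapsing tensor-parallel checkpoints: everything that follows the d_inner split (in_proj, conv1d, dt_proj, A_log, D) is concatenated across ranks along dim=0, the row-parallel projections (x_proj, out_proj) are concatenated along dim=1, and the former NO_SHARD averaging is gone. Below is a minimal, hypothetical illustration of that rule; merge_tp_shards is not a function from this repository.

```python
# Hypothetical sketch of the shard-merging rule used by the converter:
# column-parallel params concatenate along dim=0, row-parallel along dim=1.
import torch

def merge_tp_shards(shards, parallelism):
    """shards: list of tensors, one per TP rank, for a single parameter."""
    if parallelism == "column":   # e.g. mixer.in_proj.weight, mixer.conv1d.*, mixer.A_log
        return torch.cat(shards, dim=0)
    if parallelism == "row":      # e.g. mixer.out_proj.weight, mixer.x_proj.weight
        return torch.cat(shards, dim=1)
    raise ValueError(f"unknown parallelism: {parallelism}")

# Example: two ranks, each holding half of in_proj (column-parallel)
# and half of out_proj (row-parallel).
d_model, d_inner, world_size = 8, 16, 2
in_proj_shards = [torch.randn(2 * d_inner // world_size, d_model) for _ in range(world_size)]
out_proj_shards = [torch.randn(d_model, d_inner // world_size) for _ in range(world_size)]
assert merge_tp_shards(in_proj_shards, "column").shape == (2 * d_inner, d_model)
assert merge_tp_shards(out_proj_shards, "row").shape == (d_model, d_inner)
```

Row-parallel biases, per the table's own comment, would instead require summing across ranks; those entries remain commented out here.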
