
Cast A_log to float32 before exp in Mamba2Simple.forward#923

Open
Chessing234 wants to merge 1 commit into state-spaces:main from Chessing234:fix/mamba2-simple-A-log-float-cast

Conversation

@Chessing234
Contributor

Bug

Mamba2Simple.forward in mamba_ssm/modules/mamba2_simple.py computes A directly from self.A_log without promoting to float32:

zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)

A_log is registered in __init__ as

A_log = torch.log(A).to(dtype=dtype)
self.A_log = nn.Parameter(A_log)

so it lives in the model's configured dtype (commonly bf16 or fp16). The exp is therefore executed in low precision, and the quantised A is handed to the SSD kernels (mamba_split_conv1d_scan_combined / mamba_chunk_scan_combined).
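A minimal standalone sketch (not part of the PR, using illustrative values in place of the module's initializer) of what the low-precision exp does to A. With A_log stored in bf16, the current code also runs exp and rounds its result in bf16, whereas the fixed path upcasts first and keeps A in fp32:

```python
import torch

# Illustrative A_init values; Mamba2Simple draws them from a configured range.
A_fp32 = torch.arange(1, 17, dtype=torch.float32)
# A_log as Mamba2Simple stores it: log(A) cast to the model dtype (here bf16).
A_log_bf16 = torch.log(A_fp32).to(torch.bfloat16)

# Current behaviour: exp executes in bf16 and the result stays quantised.
A_lowp = -torch.exp(A_log_bf16)
# Fixed behaviour: upcast before exp, matching Mamba2.forward.
A_fixed = -torch.exp(A_log_bf16.float())

print(A_lowp.dtype, A_fixed.dtype)  # torch.bfloat16 torch.float32
```

The remaining error in `A_fixed` comes only from the bf16 storage of `A_log` itself; the fix removes the extra rounding of the exp output.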

Root cause

The sibling modules, which share the same A_log storage convention, all upcast before exp:

  • mamba_ssm/modules/mamba2.py, forward (line 182): A = -torch.exp(self.A_log.float())
  • mamba_ssm/modules/mamba2.py, step (line 307): A = -torch.exp(self.A_log.float())
  • mamba_ssm/modules/mamba_simple.py, forward (line 143) and step (line 235): A = -torch.exp(self.A_log.float())

Given that Mamba2Simple is a trimmed-down version of Mamba2, the missing .float() here appears to be an oversight from when the simpler variant was factored out.

Fix

Add the .float() cast so exp runs in fp32, matching the reference behaviour:

A = -torch.exp(self.A_log.float())  # (nheads) or (d_inner, d_state)

Why the fix is correct

  • Single-line change, identical to the expression used in Mamba2.forward, Mamba2.step, and both Mamba paths — so Mamba2Simple now produces numerically consistent A values with the non-simple path under mixed-precision.
  • A_log._no_weight_decay = True and gradient flow are unaffected: .float() creates an upcast view in the forward graph; autograd still propagates gradients back into self.A_log.
  • Downstream kernels accept the upcast tensor (the non-simple module has been running with fp32 A the entire time).
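The gradient-flow claim in the second bullet can be checked with a minimal sketch (not from the PR; the parameter here is a standalone stand-in for self.A_log). The .float() upcast is an autograd op, so backward still reaches the bf16 parameter and the gradient arrives in the parameter's own dtype:

```python
import torch
import torch.nn as nn

# Stand-in for self.A_log: log of illustrative A values, stored in bf16.
A_log = nn.Parameter(torch.log(torch.arange(1.0, 5.0)).to(torch.bfloat16))

A = -torch.exp(A_log.float())  # the fixed expression; upcast participates in autograd
loss = A.sum()
loss.backward()

print(A_log.grad is not None)  # True: gradient propagates through .float()
print(A_log.grad.dtype)        # torch.bfloat16: grad matches the parameter dtype
```

Analytically d(sum(-exp(x)))/dx = -exp(x), and the backward of the cast returns the gradient in bf16, so the parameter updates exactly as before the fix.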

Mamba2Simple stores A_log as nn.Parameter(torch.log(A).to(dtype=dtype))
— i.e., in the model's configured dtype (often bf16/fp16). The forward
pass then does A = -torch.exp(self.A_log) directly, so the log/exp
round-trip runs in low precision and feeds a quantised A into the SSD
kernels.

The non-simple Mamba2 module (mamba_ssm/modules/mamba2.py) is the
reference implementation for the same parameter and explicitly upcasts
before exp in both the forward and step paths:

    A = -torch.exp(self.A_log.float())  # (nheads) or (d_inner, d_state)

mamba_simple.py follows the same convention. Mamba2Simple appears to
be a direct reduction of mamba2.py and just missed the .float() cast.

Match the reference behaviour so Mamba2Simple produces numerically
consistent A values in mixed-precision runs.
