128 changes: 103 additions & 25 deletions gptqmodel/hf_minimax_m2/modeling_minimax_m2.py
@@ -126,9 +126,9 @@ def __init__(self, config: MiniMaxM2Config) -> None:
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
gate = self.act_fn(self.w1(hidden_states))
up = self.w3(hidden_states)
- hidden_states = gate * up
- hidden_states = self.w2(hidden_states)
- return hidden_states
gate.mul_(up)
del up
return self.w2(gate)
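
For reference, here is a minimal self-contained sketch of the in-place gate/up pattern used above (the layer sizes and the SiLU activation are illustrative assumptions, not values from the MiniMax M2 config). The elementwise product reuses the gate buffer instead of allocating a third intermediate-sized tensor, which is only safe when autograd does not need the overwritten activation:

import torch
import torch.nn as nn

class InplaceGatedMLP(nn.Module):
    def __init__(self, hidden: int = 512, intermediate: int = 2048) -> None:
        super().__init__()
        self.w1 = nn.Linear(hidden, intermediate, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden, intermediate, bias=False)  # up projection
        self.w2 = nn.Linear(intermediate, hidden, bias=False)  # down projection
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.act_fn(self.w1(x))
        up = self.w3(x)
        gate.mul_(up)   # reuse the gate buffer instead of allocating gate * up
        del up          # drop the up projection as soon as it has been consumed
        return self.w2(gate)

mlp = InplaceGatedMLP()
with torch.inference_mode():  # in-place ops are safe here because autograd is not tracking
    y = mlp(torch.randn(4, 16, 512))
print(y.shape)  # torch.Size([4, 16, 512])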


class MiniMaxM2SparseMoeBlock(nn.Module):
@@ -168,7 +168,8 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
1.0 - self.jitter_noise,
1.0 + self.jitter_noise,
)
- hidden_states = hidden_states * noise
hidden_states.mul_(noise)
del noise

hidden_states = hidden_states.view(-1, hidden_dim)
gate_dtype = self.gate.weight.dtype
@@ -188,7 +189,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens

if correction_bias is not None:
original_scores = scores
# NOTE: kept out-of-place -- original_scores aliases scores, so an in-place add_
# here would also shift the unbiased scores that are gathered into routing_weights below.
scores = scores + correction_bias
else:
original_scores = scores
topk_scores: torch.Tensor
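
A toy illustration (made-up values) of why the correction-bias add above is kept out-of-place: original_scores is just another name for the same tensor, so an in-place add_ would silently bias the scores that are later gathered as routing weights.

import torch

scores = torch.tensor([[0.2, 0.5, 0.3]])
correction_bias = torch.tensor([0.1, 0.0, -0.1])

original_scores = scores                     # a second name for the same storage, not a copy
selection_scores = scores + correction_bias  # out-of-place: original_scores is untouched

scores.add_(correction_bias)                 # in-place: original_scores now holds the biased values too
print(original_scores)                       # no longer the unbiased scores
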
@@ -216,24 +217,42 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
routing_weights = original_scores.gather(1, selected_experts)
else:
routing_weights = topk_scores
del scores, original_scores, topk_scores

- routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12)
routing_weights.div_(routing_weights.sum(dim=-1, keepdim=True).clamp(min=1e-12))
if self.routed_scaling_factor != 1.0:
- routing_weights = routing_weights * self.routed_scaling_factor
routing_weights.mul_(self.routed_scaling_factor)
routing_weights = routing_weights.to(hidden_states.dtype)
selected_experts = selected_experts.to(torch.long)

final_hidden_states = torch.zeros_like(hidden_states)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
del selected_experts
expert_hit = torch.nonzero(expert_mask.sum(dim=(-1, -2)) > 0, as_tuple=False).flatten()

# To further reduce memory, process tokens routed to each expert in chunks
# instead of all at once. A chunk size of 1024 is a reasonable default.
EXPERT_CHUNK_SIZE = 1024

for expert_idx in expert_hit.tolist():
expert_layer = self.experts[expert_idx]
- idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
- token_states = hidden_states.index_select(0, top_x)
- expert_output = expert_layer(token_states) * routing_weights[top_x, idx].unsqueeze(-1)
- final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
idx_full, top_x_full = torch.where(expert_mask[expert_idx].squeeze(0))

for i in range(0, top_x_full.size(0), EXPERT_CHUNK_SIZE):
top_x = top_x_full[i : i + EXPERT_CHUNK_SIZE]
idx = idx_full[i : i + EXPERT_CHUNK_SIZE]

token_states = hidden_states.index_select(0, top_x)
expert_output = expert_layer(token_states)

weights = routing_weights[top_x, idx].unsqueeze(-1)
expert_output.mul_(weights)

final_hidden_states.index_add_(0, top_x, expert_output.to(final_hidden_states.dtype))
del expert_output, token_states, idx, top_x, weights

del idx_full, top_x_full
del hidden_states, routing_weights, expert_mask, expert_hit
final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
return final_hidden_states, router_logits
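
For reference, a self-contained sketch of the chunked expert dispatch above, reduced to a single expert; the nn.Linear expert, sizes, routing indices, and weights are stand-ins, not values from the model. Only one chunk of gathered token activations and one chunk-sized expert output are alive at a time, and results are accumulated into the output buffer with index_add_:

import torch
import torch.nn as nn

EXPERT_CHUNK_SIZE = 1024
num_tokens, hidden_dim = 5000, 64

hidden_states = torch.randn(num_tokens, hidden_dim)
final_hidden_states = torch.zeros_like(hidden_states)
expert_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)

# pretend routing sent every third token to this expert with weight 0.5
top_x_full = torch.arange(0, num_tokens, 3)
weights_full = torch.full((top_x_full.size(0),), 0.5)

with torch.no_grad():
    for i in range(0, top_x_full.size(0), EXPERT_CHUNK_SIZE):
        top_x = top_x_full[i : i + EXPERT_CHUNK_SIZE]
        token_states = hidden_states.index_select(0, top_x)  # only this chunk is gathered
        expert_output = expert_layer(token_states)
        expert_output.mul_(weights_full[i : i + EXPERT_CHUNK_SIZE].unsqueeze(-1))
        final_hidden_states.index_add_(0, top_x, expert_output)

print(final_hidden_states.shape)  # torch.Size([5000, 64])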

@@ -302,11 +321,15 @@ def forward(
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device

# projections
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
del hidden_states

# optional QK normalization
if self.use_qk_norm:
q_flat = query_states.transpose(1, 2).reshape(bsz * q_len, -1)
k_flat = key_states.transpose(1, 2).reshape(bsz * q_len, -1)
@@ -315,6 +338,7 @@
query_states = q_flat.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = k_flat.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

# rotary embeddings
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
@@ -326,34 +350,88 @@ def forward(
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)

# handle cache
if past_key_values is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

- attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling
- if attention_mask is not None:
- attn_weights = attn_weights + attention_mask
query_dtype = query_states.dtype
key_len = key_states.shape[-2]

# precompute sliding-window mask
window_mask = None
if self.sliding_window is not None and past_key_values is None:
- query_positions = torch.arange(q_len, device=hidden_states.device).view(1, 1, q_len, 1)
- key_positions = torch.arange(key_states.shape[-2], device=hidden_states.device).view(1, 1, 1, -1)
- window_mask = key_positions < (query_positions - self.sliding_window)
- if window_mask.any():
- attn_weights = attn_weights.masked_fill(window_mask, float("-inf"))
q_pos = torch.arange(q_len, device=device).view(1, 1, q_len, 1)
k_pos = torch.arange(key_len, device=device).view(1, 1, 1, key_len)
wm = k_pos < (q_pos - self.sliding_window)
if wm.any():
window_mask = wm.squeeze(1) # (1, q_len, key_len)
del q_pos, k_pos, wm

attn_output_parts = []
attn_weights_list = [] if output_attentions else None

for h in range(self.num_heads):
# per-head slices: q is (bsz, q_len, head_dim); k and v are (bsz, key_len, head_dim)
q = query_states[:, h, :, :]
k = key_states[:, h, :, :]
v = value_states[:, h, :, :]

# Chunked attention computation to reduce peak memory usage
out_parts = []
attn_parts = [] if output_attentions else None

# A smaller chunk size reduces memory but may be slightly slower
chunk_size = 1024
for i in range(0, q.size(1), chunk_size):
q_chunk = q[:, i:i + chunk_size, :]

# attn_chunk has shape (bsz, chunk_size, key_len)
attn_chunk = torch.matmul(q_chunk, k.transpose(-2, -1))
attn_chunk.mul_(self.scaling)

# Apply masks to the chunk
if attention_mask is not None:
attn_chunk.add_(attention_mask.squeeze(1)[:, i:i + chunk_size, :])

if window_mask is not None:
attn_chunk.masked_fill_(window_mask[:, i:i + chunk_size, :], float("-inf"))

attn_chunk = torch.softmax(attn_chunk, dim=-1, dtype=torch.float32).to(query_dtype)

if self.training and self.attention_dropout > 0:
attn_chunk = F.dropout(attn_chunk, p=self.attention_dropout, training=True)

if output_attentions:
attn_parts.append(attn_chunk)

# out_chunk has shape (bsz, chunk_size, head_dim)
out_chunk = torch.matmul(attn_chunk, v)
out_parts.append(out_chunk)

del q_chunk, attn_chunk, out_chunk

out = torch.cat(out_parts, dim=1)
attn_output_parts.append(out)

if output_attentions:
attn = torch.cat(attn_parts, dim=1)
attn_weights_list.append(attn)
del attn, attn_parts

del q, k, v, out, out_parts

attn_output = torch.stack(attn_output_parts, dim=1)
del attn_output_parts, query_states, key_states, value_states

- attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
- if self.training and self.attention_dropout > 0:
- attn_weights = F.dropout(attn_weights, p=self.attention_dropout)
attn_weights = torch.stack(attn_weights_list, dim=1) if output_attentions else None

- attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)

- if not output_attentions:
- attn_weights = None
return attn_output, attn_weights
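
For reference, a standalone sketch of the query-chunked attention pattern above, stripped of the attention mask, sliding-window mask, dropout, and KV-cache handling; the function name, sizes, and chunk size are illustrative assumptions. Each iteration materializes only a chunk_size x key_len block of scores instead of the full q_len x key_len matrix; in the real code the masks are sliced with the same chunk indices before the softmax:

import torch

def chunked_attention(q, k, v, scaling, chunk_size=1024):
    # q: (bsz, q_len, head_dim); k, v: (bsz, key_len, head_dim)
    out_parts = []
    for i in range(0, q.size(1), chunk_size):
        q_chunk = q[:, i:i + chunk_size, :]
        attn_chunk = torch.matmul(q_chunk, k.transpose(-2, -1))  # (bsz, chunk, key_len)
        attn_chunk.mul_(scaling)
        attn_chunk = torch.softmax(attn_chunk, dim=-1, dtype=torch.float32).to(q.dtype)
        out_parts.append(torch.matmul(attn_chunk, v))            # (bsz, chunk, head_dim)
        del q_chunk, attn_chunk                                  # release this chunk's score matrix early
    return torch.cat(out_parts, dim=1)

bsz, q_len, key_len, head_dim = 2, 4096, 4096, 64
q, k, v = (torch.randn(bsz, n, head_dim) for n in (q_len, key_len, key_len))
out = chunked_attention(q, k, v, scaling=head_dim ** -0.5, chunk_size=512)
print(out.shape)  # torch.Size([2, 4096, 64]); peak score buffer is 512 x 4096 per batch element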

