Fix scoped linear all-reduce for starcoder model (#1432)
skavulya authored Oct 20, 2024
1 parent f98688d commit 3d047db
Showing 2 changed files with 17 additions and 16 deletions.
6 changes: 4 additions & 2 deletions optimum/habana/transformers/models/modeling_all_models.py
@@ -164,5 +164,7 @@ def all_reduce(self, input):
         dist.inference_all_reduce(input, group=self.mp_group)
 
     def post_all_reduce(self, input):
-        output = input + self.bias if (self.bias is not None) else input
-        return output
+        # inplace addition needed for correct results
+        if self.bias is not None:
+            input += self.bias
+        return input
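For context, the pattern being patched is row-parallel tensor parallelism: each rank computes a partial matmul without the bias, the partials are summed across ranks by the all-reduce, and the bias is applied exactly once afterwards. The fix makes that last step an in-place add, since callers keep operating on the very tensor they passed in. Below is a minimal sketch of the three-phase pattern, assuming plain torch.distributed in place of the inference_all_reduce call above; the class name ScopedLinearSketch is invented for illustration.

    import torch
    import torch.distributed as dist

    class ScopedLinearSketch(torch.nn.Module):
        """Sketch of a row-parallel linear whose collective and bias add are split out."""

        def __init__(self, linear: torch.nn.Linear, mp_group=None):
            super().__init__()
            self.weight = linear.weight  # this rank's shard of the weight
            self.bias = linear.bias
            self.mp_group = mp_group

        def forward(self, input):
            # Partial matmul only; the bias is left out so the all-reduce
            # does not sum it once per rank.
            return torch.matmul(input, self.weight.t())

        def all_reduce(self, input):
            # Sum the partial outputs across the tensor-parallel group.
            dist.all_reduce(input, group=self.mp_group)

        def post_all_reduce(self, input):
            # In-place add, as in the fix above: callers hold a reference to
            # this exact tensor, so a freshly allocated result could be lost.
            if self.bias is not None:
                input += self.bias
            return input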
27 changes: 13 additions & 14 deletions optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
@@ -61,17 +61,19 @@

 class GaudiStarcoder2MLP(Starcoder2MLP):
     def pre_mlp_forward(self, x):
-        inputs = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
-        output = self.down_proj(inputs)
-        return output
+        x = self.c_fc(x)
+        x = self.act(x)
+        x = self.c_proj(x)
+        x = F.dropout(x, p=self.residual_dropout, training=self.training)
+        return x
 
     def mlp_all_reduce(self, x):
-        if hasattr(self.down_proj, "all_reduce"):
-            self.down_proj.all_reduce(x)
+        if hasattr(self.c_proj, "all_reduce"):
+            self.c_proj.all_reduce(x)
 
     def post_mlp_forward(self, x):
-        if hasattr(self.down_proj, "post_all_reduce"):
-            return self.down_proj.post_all_reduce(x)
+        if hasattr(self.c_proj, "post_all_reduce"):
+            return self.c_proj.post_all_reduce(x)
         return x
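The bug: the previous pre_mlp_forward was a copy of the LLaMA-style gated MLP (gate_proj, up_proj, down_proj), attributes Starcoder2MLP does not have; Starcoder2's MLP is c_fc, activation, c_proj, then residual dropout. Split into three phases, the sharded MLP composes back to the original forward. A sketch of the assumed call sequence (sharded_mlp_forward is illustrative, not part of the diff):

    def sharded_mlp_forward(mlp, x):
        x = mlp.pre_mlp_forward(x)      # c_fc -> act -> c_proj (partial) -> dropout
        mlp.mlp_all_reduce(x)           # in-place sum of partial c_proj outputs across ranks
        return mlp.post_mlp_forward(x)  # add the c_proj bias exactly once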


@@ -431,13 +433,10 @@ def forward(
             flash_attention_causal_mask=flash_attention_causal_mask,
             cache_idx=cache_idx,
         )
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
+        self.self_attn.attention_all_reduce(hidden_states)
+        hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual)
+        self.mlp.mlp_all_reduce(hidden_states)
+        hidden_states = self.post_mlp(hidden_states, residual)
 
         outputs = (hidden_states,)
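In the layer forward, the inline residual adds and the post-attention layernorm move behind two helpers so that each all-reduce is issued explicitly between them. The helper bodies are not part of this diff; a plausible sketch of what they compute, assuming the usual pre-norm residual structure (post_attn_forward is likewise an assumed bias-add helper on the attention output projection):

    def post_attn_pre_mlp(self, hidden_states, residual):
        # Add the attention output-projection bias after its all-reduce
        # (assumed helper), then do the residual add and layernorm that
        # the removed lines performed inline.
        hidden_states = self.self_attn.post_attn_forward(hidden_states)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        return hidden_states, residual

    def post_mlp(self, hidden_states, residual):
        # Add the c_proj bias after its all-reduce, then the final residual add.
        hidden_states = self.mlp.post_mlp_forward(hidden_states)
        return residual + hidden_states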

