diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py
index 90aa2d5e0..a76922024 100644
--- a/optimum/habana/transformers/models/modeling_all_models.py
+++ b/optimum/habana/transformers/models/modeling_all_models.py
@@ -164,5 +164,7 @@ def all_reduce(self, input):
         dist.inference_all_reduce(input, group=self.mp_group)
 
     def post_all_reduce(self, input):
-        output = input + self.bias if (self.bias is not None) else input
-        return output
+        # inplace addition needed for correct results
+        if self.bias is not None:
+            input += self.bias
+        return input
diff --git a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
index c50af7ff2..b01a17636 100644
--- a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
@@ -61,17 +61,19 @@ class GaudiStarcoder2MLP(Starcoder2MLP):
     def pre_mlp_forward(self, x):
-        inputs = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
-        output = self.down_proj(inputs)
-        return output
+        x = self.c_fc(x)
+        x = self.act(x)
+        x = self.c_proj(x)
+        x = F.dropout(x, p=self.residual_dropout, training=self.training)
+        return x
 
     def mlp_all_reduce(self, x):
-        if hasattr(self.down_proj, "all_reduce"):
-            self.down_proj.all_reduce(x)
+        if hasattr(self.c_proj, "all_reduce"):
+            self.c_proj.all_reduce(x)
 
     def post_mlp_forward(self, x):
-        if hasattr(self.down_proj, "post_all_reduce"):
-            return self.down_proj.post_all_reduce(x)
+        if hasattr(self.c_proj, "post_all_reduce"):
+            return self.c_proj.post_all_reduce(x)
         return x
@@ -431,13 +433,10 @@ def forward(
             flash_attention_causal_mask=flash_attention_causal_mask,
             cache_idx=cache_idx,
         )
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
+        self.self_attn.attention_all_reduce(hidden_states)
+        hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual)
+        self.mlp.mlp_all_reduce(hidden_states)
+        hidden_states = self.post_mlp(hidden_states, residual)
 
         outputs = (hidden_states,)
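The decoder-layer hunk above removes the residual additions and `post_attention_layernorm` call and delegates them to `post_attn_pre_mlp` and `post_mlp`, which are defined outside this diff. Below is a minimal, non-authoritative sketch of what those helpers are assumed to look like, reconstructed only from the lines removed above; the `post_attn_forward` hook name is an assumption made to mirror `post_mlp_forward`, and the real implementations may differ.

```python
# Hypothetical sketch only: reconstructed from the removed lines in the hunk above,
# not copied from the actual modeling_starcoder2.py helpers.
def post_attn_pre_mlp(self, hidden_states, residual):
    # Finish the attention block after its all-reduce (e.g. bias add), then add the residual.
    hidden_states = self.self_attn.post_attn_forward(hidden_states)  # assumed hook name
    hidden_states = residual + hidden_states

    # Prepare the MLP block: save the new residual, normalize, run the pre-all-reduce MLP part.
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp.pre_mlp_forward(hidden_states)
    return hidden_states, residual


def post_mlp(self, hidden_states, residual):
    # Finish the MLP block after its all-reduce (bias add via post_all_reduce), then add the residual.
    hidden_states = self.mlp.post_mlp_forward(hidden_states)
    hidden_states = residual + hidden_states
    return hidden_states
```

This split lets each all-reduce (`attention_all_reduce`, `mlp_all_reduce`) run on the projection output before the bias and residual are applied, which is why `post_all_reduce` adds the bias in place on the reduced tensor.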