Fix scoped linear all-reduce for starcoder model #1432
Conversation
```python
output = F.dropout(x, p=self.residual_dropout, training=self.training)
return output
```
Suggested change:

```diff
-output = F.dropout(x, p=self.residual_dropout, training=self.training)
-return output
+x = F.dropout(x, p=self.residual_dropout, training=self.training)
+return x
```
@skavulya why do you allocate new memory here? Is there a reason you do not overwrite x?
Good catch. I will overwrite x, similar to the original implementation.
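The suggestion above can be sketched in isolation. This is a toy, pure-Python stand-in for `F.dropout` (the names `toy_dropout` and `forward` are hypothetical, not optimum-habana code): reassigning the result back to `x` produces the same return value as binding a new `output` name, and in PyTorch the reassignment additionally drops the last reference to the pre-dropout activation so its storage can be reclaimed.

```python
import random

def toy_dropout(x, p, training):
    # Toy stand-in for F.dropout: identity in eval mode; inverted dropout
    # (zero with probability p, scale survivors by 1/(1-p)) in training mode.
    if not training:
        return x
    return [0.0 if random.random() < p else v / (1.0 - p) for v in x]

def forward(x, residual_dropout=0.1, training=False):
    # Overwrite x instead of introducing a new `output` name;
    # the returned values are identical either way.
    x = toy_dropout(x, p=residual_dropout, training=training)
    return x

assert forward([1.0, 2.0], training=False) == [1.0, 2.0]
```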
```diff
@@ -164,5 +164,7 @@ def all_reduce(self, input):
         dist.inference_all_reduce(input, group=self.mp_group)

     def post_all_reduce(self, input):
         output = input + self.bias if (self.bias is not None) else input
         return output
```

# inplace addition needed for correct results
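The note above says an in-place addition is needed for correct results. A toy sketch of why (pure Python, hypothetical names; `CapturedGraph` is only an analogy for a captured/replayed HPU graph, not the real API): if the runtime keeps a reference to the original buffer and replays reads from it, an out-of-place `input + bias` allocates a new object the runtime never sees, so the bias is silently dropped; mutating the original buffer makes the update visible.

```python
class CapturedGraph:
    """Toy stand-in for a captured graph: it records the buffer object
    seen at capture time and replays reads from that same object."""
    def __init__(self, buffer):
        self.buffer = buffer  # holds the original storage

    def replay(self):
        return list(self.buffer)  # reads the captured storage

def post_all_reduce_out_of_place(buf, bias):
    # Allocates a new list; the captured graph never observes the bias.
    return [x + bias for x in buf]

def post_all_reduce_inplace(buf, bias):
    # Mutates the captured storage, so a replay sees the update.
    for i in range(len(buf)):
        buf[i] += bias
    return buf

buf = [1.0, 2.0, 3.0]
graph = CapturedGraph(buf)

post_all_reduce_out_of_place(buf, 0.5)
assert graph.replay() == [1.0, 2.0, 3.0]   # bias lost on replay

post_all_reduce_inplace(buf, 0.5)
assert graph.replay() == [1.5, 2.5, 3.5]   # bias visible after in-place update
```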
@skavulya very nice catch!
Does this affect any other tests or models (falcon, llama, gemma, qwen2, qwen2_moe)?
I only looked at llama and qwen2. They were not affected because they don't use a bias. I can test falcon and qwen2_moe.
@skavulya if llama and qwen2 did not use a bias, then the output is wrong, isn't it? Since, based on the original code, `output = input + input`.
For llama and qwen2, bias was None, so the addition is skipped.
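The skip the comment above describes can be seen in a minimal sketch of the conditional (a standalone function mirroring the `post_all_reduce` shown in the diff, not the actual class method): with `bias=None` the input passes through untouched, which is why bias-free models were unaffected.

```python
def post_all_reduce(input_, bias=None):
    # With bias=None the addition is skipped and the input is returned
    # unchanged, so models without a bias (e.g. llama, qwen2) see no effect.
    return input_ + bias if bias is not None else input_

assert post_all_reduce(3.0) == 3.0        # bias-free: pass-through
assert post_all_reduce(3.0, 0.5) == 3.5   # bias present: added once
```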
Force-pushed from 23d91c2 to 30b5b12
@yafshar The results of the transformer tests on main and this PR are the same.
LGTM!
@regisss this PR is ready. Would you check it?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Good catch!
A bug in the scoped linear all-reduce implementation for the starcoder2 model caused incorrect output, as shown below:

```shell
python ../gaudi_spawn.py --use_deepspeed --world_size 2 run_generation.py --model_name_or_path bigcode/starcoder2-15b --use_hpu_graphs --trust_remote_code --attn_softmax_bf16 --trim_logits --use_kv_cache --use_flash_attention --flash_attention_recompute --max_new_tokens 128 --batch_size 1 --bf16 --prompt "def is_prime():"
```

Output before the fix:

```
Input/outputs:
input 1: ('def is_prime():',)
output 1: ('def is_prime(): ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( (',)
```

Output after the fix:

```
Input/outputs:
input 1: ('def is_prime():',)
output 1: ('def is_prime():\n    for i in range(2, int(math.sqrt(n)) + 1):\n        if n % i == 0:\n            return False\n    return True\n\ndef is_palindrome():\n    return str(n) == str(n)[::-1]\n\ndef is_lychrel():\n    for i in range(50):\n        n = n + int(str(n)[::-1])\n        if is_palindrome():\n            return False\n    return True\n\ncount = 0\nfor i in range(10000):\n    if is_lychrel():',)
```
Fixes # (issue)
Before submitting