
Fix scoped linear all-reduce for starcoder model #1432

Merged 2 commits into huggingface:main on Oct 20, 2024

Conversation

skavulya (Contributor) commented on Oct 17, 2024

A bug in the scoped linear all-reduce implementation for the starcoder2 model caused incorrect output, as shown below:

python ../gaudi_spawn.py --use_deepspeed --world_size 2 run_generation.py --model_name_or_path bigcode/starcoder2-15b --use_hpu_graphs --trust_remote_code --attn_softmax_bf16 --trim_logits --use_kv_cache --use_flash_attention --flash_attention_recompute --max_new_tokens 128 --batch_size 1 --bf16 --prompt "def is_prime():"

Output:
Input/outputs:
input 1: ('def is_prime():',)
output 1: ('def is_prime(): ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( (',)

Output after the fix:
Input/outputs:
input 1: ('def is_prime():',)
output 1: ('def is_prime():\n for i in range(2, int(math.sqrt(n)) + 1):\n if n % i == 0:\n return False\n return True\n\ndef is_palindrome():\n return str(n) == str(n)[::-1]\n\ndef is_lychrel():\n for i in range(50):\n n = n + int(str(n)[::-1])\n if is_palindrome():\n return False\n return True\n\ncount = 0\nfor i in range(10000):\n if is_lychrel():',)

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@mandy-li mandy-li requested a review from libinta October 17, 2024 03:45
@mandy-li mandy-li added the run-test Run CI for PRs from external contributors label Oct 17, 2024
Comment on lines 67 to 68
output = F.dropout(x, p=self.residual_dropout, training=self.training)
return output
Contributor
Suggested change
output = F.dropout(x, p=self.residual_dropout, training=self.training)
return output
x = F.dropout(x, p=self.residual_dropout, training=self.training)
return x

@skavulya why do you allocate new memory here? Is there a reason you do not overwrite x?

Contributor Author

@@ -164,5 +164,7 @@ def all_reduce(self, input):
        dist.inference_all_reduce(input, group=self.mp_group)

    def post_all_reduce(self, input):
-       output = input + self.bias if (self.bias is not None) else input
-       return output
+       # inplace addition needed for correct results
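The added lines of the hunk are truncated above, but the new comment indicates the fix switches to in-place addition. A minimal sketch of why that matters when a graph captures a fixed buffer, using numpy as a stand-in for HPU tensors (function names are illustrative, not the actual optimum-habana code):

```python
import numpy as np

def post_all_reduce_old(input, bias):
    # Original pattern: `input + bias` allocates a NEW array, so the
    # captured buffer that downstream ops keep reading is never updated.
    output = input + bias if (bias is not None) else input
    return output

def post_all_reduce_fixed(input, bias):
    # Fixed pattern: in-place addition writes the bias into the captured
    # buffer itself ("inplace addition needed for correct results").
    if bias is not None:
        input += bias
    return input

captured = np.ones(4)                 # buffer a graph would replay from
post_all_reduce_old(captured, np.full(4, 2.0))
print(captured.tolist())              # [1.0, 1.0, 1.0, 1.0] -- bias lost

captured = np.ones(4)
post_all_reduce_fixed(captured, np.full(4, 2.0))
print(captured.tolist())              # [3.0, 3.0, 3.0, 3.0] -- bias applied
```

Any consumer holding a reference to the original buffer only sees the bias in the in-place version, which matches the repeated-token output reported before the fix.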
Contributor

@skavulya very nice catch!

Does this affect any other tests or models (falcon, llama, gemma, qwen2, qwen2_moe)?

Contributor Author

I only looked at llama and qwen2. They were not affected because they do not use a bias. I can test falcon and qwen2_moe.

Contributor

@skavulya if llama and qwen2 did not use a bias, then isn't the output wrong? Based on the original code, output = input + input.

Contributor Author

For llama and qwen2, bias was None, so the addition is skipped.

skavulya (Contributor Author)
@yafshar The results of the transformers tests on main and this PR are the same:

GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/transformers/tests/models/
main: commit f98688d ==== 41 failed, 993 passed, 401 skipped, 101 warnings in 1531.33s (0:25:31) ====
this pr: ==== 41 failed, 993 passed, 401 skipped, 101 warnings in 1296.93s (0:21:36) ====

yafshar (Contributor) left a comment

LGTM!

@regisss this PR is ready. Would you check it?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

regisss (Collaborator) left a comment

Good catch!

@regisss regisss merged commit 3d047db into huggingface:main Oct 20, 2024
2 of 4 checks passed
Labels: run-test (Run CI for PRs from external contributors)
5 participants