Add support for BLOOM #331

Merged
merged 15 commits into main on Jul 3, 2023

Conversation

WoosukKwon (Collaborator):

Closes #61

This PR adds the BLOOM model and modifies the paged attention kernel to support ALiBi bias.
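
For background, ALiBi replaces positional embeddings with a per-head linear bias added to the attention scores, where each head uses a fixed slope drawn from a geometric sequence. Below is a minimal sketch of how those slopes are conventionally derived (following the ALiBi paper and the Hugging Face BLOOM implementation; the helper name is illustrative and not necessarily the code added in this PR):

```python
import math
import torch

def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    """Per-head ALiBi slopes: a geometric sequence, e.g. 1/2, 1/4, ..., 1/256 for 8 heads."""
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = [base ** (i + 1) for i in range(closest_power_of_2)]
    if closest_power_of_2 != num_heads:
        # Interpolate additional slopes when num_heads is not a power of two.
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_extra = num_heads - closest_power_of_2
        slopes += [extra_base ** (2 * i + 1) for i in range(num_extra)]
    return torch.tensor(slopes, dtype=torch.float32)
```

Each slope then scales a linear distance penalty that is added to the query-key scores before the softmax, which is what the kernel change below provides for the decoding path.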

WoosukKwon requested a review from zhuohan123 on July 2, 2023 at 07:35
WoosukKwon (Collaborator, Author):

@zhuohan123 I've fixed the PR to follow our new formatter. It should be ready for review now. Please take a look!

zhuohan123 (Member) left a review comment:

LGTM! Thank you for your hard work! I left some questions about design choices. In addition, what's the speed difference between similarly sized LLaMA and BLOOM models?


float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;

zhuohan123 (Member):

This condition seems unnecessary. If alibi_slope == 0, then alibi_slope * (token_idx - context_len) will be 0 as well.

WoosukKwon (Collaborator, Author):

Yes. It's to avoid the redundant computation of 0 * (token_idx - context_len).
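
For readers following the thread, a hypothetical host-side reference of the same bias term, applied to one query's scores during decoding, looks like this (the function and argument names are illustrative only):

```python
import torch

def add_alibi_bias(scores: torch.Tensor, alibi_slope: float,
                   context_len: int) -> torch.Tensor:
    """scores: the [context_len] vector of raw q.k products for one head
    and one query token."""
    token_idx = torch.arange(context_len, dtype=scores.dtype)
    # Same term as in the kernel: alibi_slope * (token_idx - context_len).
    # When alibi_slope == 0 the bias is all zeros, so the kernel's guard only
    # skips a multiply; it does not change the result.
    bias = alibi_slope * (token_idx - context_len)
    return scores + bias
```

Since softmax is invariant to adding the same constant to every score in a row, the constant offset in this formulation of the bias does not affect the resulting attention weights.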

@@ -53,13 +55,21 @@ def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.

zhuohan123 (Member):

Why did you choose this design, instead of explicitly initializing attn_bias in advance, say at the beginning of BLOOM's forward function?

WoosukKwon (Collaborator, Author):

Good question. It's because alibi_slopes is stored in the attention layer. If we want to create the attention bias in BloomForCausalLM, then we have to store alibi_slopes in both places (because alibi_slopes is also used for the decoding attention).

I kind of agree that this design is not ideal, but I couldn't find a better way to do it.
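
To make the trade-off concrete, here is a simplified sketch (assumed structure, not this PR's exact code) of the lazy-initialization pattern under discussion: the first attention layer to run builds the ALiBi prompt bias from its own alibi_slopes and caches it on the shared InputMetadata, so later layers, and BloomForCausalLM itself, never need a second copy of the slopes. Causal masking is assumed to be handled elsewhere, and input_metadata.attn_bias is assumed to be a list (hence the truthiness check):

```python
import torch

class PagedAttentionWithALiBi:  # illustrative class name
    def __init__(self, alibi_slopes: list) -> None:
        self.alibi_slopes = torch.tensor(alibi_slopes)  # shape: [num_heads]

    def set_attn_bias(self, input_metadata) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer; every layer shares the same bias.
            return
        for prompt_len in input_metadata.prompt_lens:
            pos = torch.arange(prompt_len)
            # Relative distance key_pos - query_pos: shape [prompt_len, prompt_len].
            rel = pos[None, :] - pos[:, None]
            # One bias tensor per prompt: shape [num_heads, prompt_len, prompt_len].
            bias = self.alibi_slopes[:, None, None] * rel[None, :, :]
            input_metadata.attn_bias.append(bias)
```

Sharing one cached bias this way avoids rebuilding it (or duplicating alibi_slopes) once per layer, at the cost of the somewhat implicit control flow questioned above.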

emsi mentioned this pull request on Jul 3, 2023
WoosukKwon merged commit e41f067 into main on Jul 3, 2023
WoosukKwon deleted the bloom branch on July 3, 2023 at 20:12

Hukongtao:

I used vLLM to try to speed up my BLOOM model, but I found that the speed did not improve. Moreover, vLLM's memory usage is higher. What might be the reason?
vLLM: [screenshot of benchmark output]
HF: [screenshot of benchmark output]

Hukongtao:

@WoosukKwon do you have any benchmarks on speed and memory for BLOOM?

hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Sep 30, 2024
vllm.utils.is_hpu() has been redundant for some time now and has always been
problematic, particularly for torch.compile mode. Now we're fully
switching to current_platform.is_hpu().
Development

Successfully merging this pull request may close these issues.

Support BLOOM
3 participants