[CB] Allow block sharing in hybrid models #42877
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looks cool, let me know if I can help add some docs for it here!
@stevhliu I am pushing a lot of new features to CB right now, will let you know when things settle down so we can do a big push on docs! Thanks!
ArthurZucker left a comment
Sounds good! The perf numbers don't show a diff; they should for models that were hybrid, no?
self.use_prefix_sharing = allow_prefix_sharing and group_types == ["full_attention"]
self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
# We only use prefix sharing if the whole model has only full attention layers and block sharing is allowed
self.use_prefix_sharing = allow_block_sharing and group_types == ["full_attention"]
no rename on the self attr?
Good question! prefix_sharing is only possible if block sharing (which is more of a memory optimization) is enabled AND the model has no sliding window layers: if there are any, they will create sliding window groups with no shareable blocks, hence no prefix sharing.
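To make the distinction explicit, here is a minimal sketch of how the two settings relate, assuming names like `allow_block_sharing` and `group_types` that mirror the snippet quoted above (it is not the PR's exact implementation):

```python
def resolve_sharing_flags(allow_block_sharing: bool, group_types: list[str]) -> tuple[bool, bool]:
    # Block sharing (mainly a memory optimization) only needs the flag to be on.
    use_block_sharing = allow_block_sharing
    # Prefix sharing additionally requires every layer group to be full attention:
    # a sliding-window group would hold non-shareable blocks and break the prefix chain.
    use_prefix_sharing = allow_block_sharing and group_types == ["full_attention"]
    return use_block_sharing, use_prefix_sharing
```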
for b in range(len(prompt_ids) // self.block_size):
    tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
    current_hash = self._block_manager.compute_hash(current_hash, tokens)
# Prefix sharing is only supported when there is only one full attention layer group, so group_id=0.
Is the comment still valid? I thought this PR allowed different groups to exist, and only acts on one.
Yes, because prefix sharing is still only activated if there is only one group -- cf. the comment above.
But then to mark as complete we loop over the groups, no?
Yes, because marking a block as complete is useful in the context of block sharing, which can happen in a hybrid model. But here we are in the context of prefix sharing, which is more restrictive, so we know there is only one group. Maybe I am missing something here.
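To make the two code paths being discussed concrete, here is a hedged sketch (all names and the hash function are hypothetical, not the PR's actual code): prefix lookup hashes the prompt one full block at a time and only ever concerns group 0, while completion marking loops over every group because block sharing also applies to hybrid models:

```python
import hashlib
from dataclasses import dataclass

BLOCK_SIZE = 32  # illustrative value

def compute_hash(parent_hash: str, tokens: list[int]) -> str:
    # Chain the parent hash with this block's tokens so two requests share a
    # block only when their entire prefix up to that block matches.
    payload = parent_hash + "," + ",".join(map(str, tokens))
    return hashlib.sha256(payload.encode()).hexdigest()

def prompt_block_hashes(prompt_ids: list[int]) -> list[str]:
    # Prefix sharing walks the prompt one full block at a time; it is only
    # enabled for single-group full-attention models, hence group_id=0 only.
    hashes, current = [], ""
    for b in range(len(prompt_ids) // BLOCK_SIZE):
        tokens = prompt_ids[b * BLOCK_SIZE : (b + 1) * BLOCK_SIZE]
        current = compute_hash(current, tokens)
        hashes.append(current)
    return hashes

@dataclass
class SketchBlock:
    block_id: int
    is_complete: bool = False

def mark_shareable_blocks_as_complete(blocks_per_group: dict[int, list[SketchBlock]]) -> None:
    # Block sharing applies to every full-attention group of a hybrid model,
    # so completion marking loops over all groups, not just group 0.
    for blocks in blocks_per_group.values():
        for block in blocks:
            block.is_complete = True
```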
| a no-op.""" | ||
| num_complete_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id) | ||
| def mark_shareable_blocks_as_complete(self, state: RequestState) -> None: | ||
| """Marks the shareable blocks that have been computed in the forward pass as complete. If block sharing is off, |
| """Marks the shareable blocks that have been computed in the forward pass as complete. If block sharing is off, | |
| """Marks the shareable blocks that have been computed in the forward pass as complete (meaning it contains cache for tokens that are already processed, vs empty cache for futur new tokens). If block sharing is off, |
Not entirely, but I can see the confusion. Will adapt
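As a rough illustration of what "complete" means in this thread (hypothetical names, not the PR's actual implementation): a block is complete once every one of its slots holds KV entries for tokens that have already gone through a forward pass, and only such blocks are candidates for sharing:

```python
def is_block_complete(num_cached_tokens_in_block: int, block_size: int) -> bool:
    # A partially filled block still has empty slots reserved for future tokens,
    # so it cannot be hashed and shared yet.
    return num_cached_tokens_in_block == block_size
```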
class Block:
class Block:  # TODO: rename to ShareableBlock and update the docs
SlidingBlock(Block) is not shareable
Yes, so they won't create this kind of object! Hence the proposed name change.
# If the block is shareable, we keep track of the allocated blocks as partial blocks
if shareable:
In general with transformers we don't want if/else, we want 2 classes: one for shareable, one for non-shareable. Splitting the logic by class usually scales better.
The thing is, for non-shareable blocks, we don't need the block object at all: we just want the physical block_id. Hence no need to create a Python object that will never be used and keep track of it. So less overhead!
ok got it!
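A sketch of the design choice described above (names are illustrative): full-attention groups track block objects so they can later be hashed, marked complete, and reference-counted, while sliding-window groups only keep the raw physical block ids and pay no extra bookkeeping cost:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ShareableBlock:
    """Bookkeeping for a block that may be reused across requests."""
    physical_block_id: int
    hash: Optional[str] = None  # set once the block is complete
    ref_count: int = 1          # how many requests currently point at this block

def allocate(shareable: bool, free_block_ids: list[int], partial_blocks: dict[int, ShareableBlock]) -> int:
    block_id = free_block_ids.pop()
    if shareable:
        # Shareable blocks need a Python object so they can later be hashed,
        # marked complete, and reference-counted.
        partial_blocks[block_id] = ShareableBlock(physical_block_id=block_id)
    # Non-shareable (sliding-window) blocks need no extra object: the physical
    # block id is all the bookkeeping required, which keeps overhead low.
    return block_id
```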
self._uninit_block_ids.extend(blocks)

def mark_blocks_as_complete(
def mark_shareable_blocks_as_complete(
Only shareable blocks can be marked as complete, and only shareable blocks need this logic (same as the above comment). How can we better "encapsulate" this?
I guess this will make more sense when Block is renamed to ShareableBlock. In that case, I think the function name makes sense. The issue is I renamed one but not the other imo.
ok!
self._allow_block_sharing = allow_block_sharing
self._use_prefix_sharing = allow_block_sharing  # approximation until the cache is created
Kinda weird to have 2 attrs that do the same thing
They don't! The block sharing boolean allows hybrid models to do block sharing. The use_prefix_sharing bool is updated after we know if the model is hybrid or not. Hence the comment about the approximation!
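A small sketch of the lifecycle being described (hypothetical structure, not the PR's exact code): `_allow_block_sharing` is fixed at construction, while `_use_prefix_sharing` starts out as an approximation and is tightened once the cache reveals whether the model is hybrid:

```python
class SchedulerSketch:
    def __init__(self, allow_block_sharing: bool) -> None:
        self._allow_block_sharing = allow_block_sharing
        # Approximation: we do not yet know whether the model is hybrid.
        self._use_prefix_sharing = allow_block_sharing

    def on_cache_created(self, group_types: list[str]) -> None:
        # Once the cache exists, the real layer group structure is known, so
        # the prefix-sharing flag can be corrected for hybrid models.
        self._use_prefix_sharing = self._allow_block_sharing and group_types == ["full_attention"]
```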
Not really sadly, this is more of a memory optimization. It will ensure that parallel decoding won't eat too much memory with hybrid models. The perf table is here more as a sanity check to make sure we did not incur any major overhead with these changes.
Summary
This PR increases the granularity of KV cache sharing. Previously, sharing was only enabled for full-attention-only models. With this PR, sharing is enabled for any model that has full attention layers. It can still be disabled with a flag.
This PR paves the way for parallel decoding by making it more efficient for hybrid architectures.
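As a rough illustration of the increased granularity (the group-type strings and function name are illustrative, not the exact ones used in the code): in a hybrid model, only the full-attention groups now participate in block sharing, instead of sharing being disabled for the whole model:

```python
def shareable_groups(group_types: list[str], allow_block_sharing: bool = True) -> list[int]:
    # Before this PR: sharing required the whole model to be full attention.
    # After this PR: every full-attention group can share blocks, even if the
    # model also has sliding-window groups; the flag can still turn it off.
    if not allow_block_sharing:
        return []
    return [i for i, t in enumerate(group_types) if t == "full_attention"]

# Example: a hybrid model with one full-attention and one sliding-window group
# now shares blocks for group 0 instead of sharing nothing at all.
assert shareable_groups(["full_attention", "sliding_attention"]) == [0]
```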
Performance
Tests
No new failures for the CB tests.
Sanity check
Ran the command `python examples/pytorch/continuous_batching.py --samples 20 --add-prefix --compile --compare` and outputs were nearly the same and made sense.