Conversation

@remi-or (Collaborator) commented Dec 15, 2025

Summary

This PR increases the granularity of KV cache sharing. Previously, sharing was only enabled for models made up exclusively of full-attention layers; with this PR, sharing is enabled for any model that has full-attention layers. It can still be disabled with a flag.
This PR also paves the way for parallel decoding by making it more memory-efficient for hybrid architectures.
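
To make the new granularity concrete, here is a minimal sketch of the per-layer-group decision. The helper name `shareable_groups` and the group-type strings are illustrative assumptions based on the diff excerpts discussed below, not the actual transformers API:

```python
# Illustrative sketch only (hypothetical helper, not the transformers API).
# With this PR, block sharing is decided per layer group: full-attention groups
# can share KV cache blocks even in a hybrid model, while sliding-window groups
# keep private blocks. Prefix sharing stays restricted to the stricter case of a
# single full-attention group, i.e. group_types == ["full_attention"].
def shareable_groups(group_types: list[str], allow_block_sharing: bool = True) -> list[bool]:
    return [allow_block_sharing and g == "full_attention" for g in group_types]

hybrid = ["sliding_attention", "full_attention"]
print(shareable_groups(hybrid))                             # [False, True]
print(shareable_groups(hybrid, allow_block_sharing=False))  # [False, False] (feature disabled)
print(shareable_groups(["full_attention"]))                 # [True] -> prefix sharing also possible
```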

Performance

| Attention | Version | Generated tokens | Duration (s) | Throughput (tok/s) |
|---|---|---|---|---|
| Flash attention 3 | This PR | 113149 | 16.61 | 6811.56 |
| Flash attention 3 | Main branch | 112599 | 16.73 | 6729.27 |
| Flash attention 2 | This PR | 113670 | 24.83 | 4578.74 |
| Flash attention 2 | Main branch | 112822 | 24.61 | 4584.74 |
| SDPA | This PR | 112970 | 84.32 | 1339.76 |
| SDPA | Main branch | 113254 | 82.49 | 1373.00 |

Tests

No new failures in the continuous batching (CB) tests.

Sanity check

Ran `python examples/pytorch/continuous_batching.py --samples 20 --add-prefix --compile --compare`; the outputs were nearly the same and made sense.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@stevhliu (Member)

Looks cool, let me know if I can help add some docs for it here!

@remi-or requested a review from ArthurZucker December 17, 2025 11:17
@remi-or (Collaborator, Author) commented Dec 17, 2025

@stevhliu I am pushing a lot of new features to CB right now, will let you know when things settle down so we can do a big push on docs! Thanks!

@ArthurZucker (Collaborator) left a comment

Sounds good! The perf numbers don't show a diff; they should for models that were hybrid, no?

self.use_prefix_sharing = allow_prefix_sharing and group_types == ["full_attention"]
self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
# We only use prefix sharing if the whole model has only full attention layers and block sharing is allowed
self.use_prefix_sharing = allow_block_sharing and group_types == ["full_attention"]
Collaborator

no rename on the self attr?

Collaborator Author

Good question! prefix_sharing is only possible if block sharing (which is more of a memory optimization) is enabled AND the model has no sliding window layers: if there are any, they will create sliding window groups with no shareable blocks, hence no prefix sharing.

for b in range(len(prompt_ids) // self.block_size):
tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
current_hash = self._block_manager.compute_hash(current_hash, tokens)
# Prefix sharing is only supported when there is only one full attention layer group, so group_id=0.
Collaborator

Is the comment still valid? I thought this PR allowed different groups to exist, and only acts on one.

Collaborator Author

Yes, because prefix sharing is still only activated if there is only one group; see the comment above.

Collaborator

But then to mark blocks as complete we loop over the groups, no?

Collaborator Author

Yes, because marking a block as complete is useful in the context of block sharing, which can happen in a hybrid model. But here we are in the context of prefix sharing, which is more restrictive, so we know there is only one group. Maybe I am missing something here.
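
As background for the `compute_hash` loop in the excerpt above, here is a minimal, self-contained sketch of chained block hashing for prefix matching. Everything in it (block size, hash construction, helper names) is a hypothetical illustration, not the actual transformers implementation:

```python
import hashlib

BLOCK_SIZE = 4  # hypothetical block size, for illustration only

def compute_block_hash(parent_hash: int, tokens: list[int]) -> int:
    """Chain the parent block's hash with this block's tokens, so a block's hash
    identifies the entire prompt prefix up to and including this block."""
    payload = parent_hash.to_bytes(8, "little", signed=True) + bytes(t % 256 for t in tokens)
    return int.from_bytes(hashlib.sha256(payload).digest()[:8], "little", signed=True)

def prefix_hashes(prompt: list[int]) -> list[int]:
    hashes, current = [], -1  # -1 acts as the "no parent" sentinel
    for b in range(len(prompt) // BLOCK_SIZE):
        tokens = prompt[b * BLOCK_SIZE : (b + 1) * BLOCK_SIZE]
        current = compute_block_hash(current, tokens)
        hashes.append(current)
    return hashes

# Two requests that share their first two prompt blocks: the first two hashes
# match, so the second request could reuse the first request's completed KV
# blocks, while the third block differs and must be computed from scratch.
prompt_a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
prompt_b = [1, 2, 3, 4, 5, 6, 7, 8, 42, 43, 44, 45]
print(prefix_hashes(prompt_a))
print(prefix_hashes(prompt_b))
```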

a no-op."""
num_complete_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id)
def mark_shareable_blocks_as_complete(self, state: RequestState) -> None:
"""Marks the shareable blocks that have been computed in the forward pass as complete. If block sharing is off,
Collaborator

Suggested change
"""Marks the shareable blocks that have been computed in the forward pass as complete. If block sharing is off,
"""Marks the shareable blocks that have been computed in the forward pass as complete (meaning it contains cache for tokens that are already processed, vs empty cache for futur new tokens). If block sharing is off,

Collaborator Author

Not entirely, but I can see the confusion. Will adapt.



class Block:
class Block: # TODO: rename to ShareableBlock and update the docs
Collaborator

SlidingBlock(Block) instances are not shareable.

Collaborator Author

Yes, so they won't create this kind of object! Hence the proposed name change.

Comment on lines +117 to +118
# If the block is shareable, we keep track of the allocated blocks as partial blocks
if shareable:
Collaborator

In general with transformers we don't want if/else, we want two classes: one for shareable, one for non-shareable. Splitting the logic by class usually scales better.

Collaborator Author

The thing is, for non-shareable blocks, we don't need the block object at all: we just want the physical block_id. Hence there is no need to create and track a Python object that will never be used. So less overhead!

Collaborator

Ok, got it!
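
To make the trade-off in the exchange above concrete, here is a small sketch: shareable allocations get a bookkeeping object, while non-shareable ones are tracked as bare block ids. The class and attribute names (`ShareableBlock`, `TinyAllocator`) are hypothetical, not the transformers code:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ShareableBlock:
    """Hypothetical stand-in for the shareable-block bookkeeping object."""
    physical_id: int
    prefix_hash: Optional[int] = None  # filled in once the block is complete
    ref_count: int = 1                 # how many requests currently point at this block

class TinyAllocator:
    """Illustrative allocator: only shareable blocks get a Python object; for
    non-shareable (e.g. sliding-window) blocks the bare physical id is enough."""

    def __init__(self, num_blocks: int) -> None:
        self.free_ids = list(range(num_blocks))
        self.partial_blocks: dict[int, ShareableBlock] = {}

    def allocate(self, shareable: bool) -> int:
        block_id = self.free_ids.pop()
        if shareable:
            # Keep track of the allocated block as a partial (not yet complete) block.
            self.partial_blocks[block_id] = ShareableBlock(physical_id=block_id)
        return block_id

allocator = TinyAllocator(num_blocks=8)
full_attention_block = allocator.allocate(shareable=True)   # tracked via ShareableBlock
sliding_window_block = allocator.allocate(shareable=False)  # just an int, no object overhead
```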

self._uninit_block_ids.extend(blocks)

def mark_blocks_as_complete(
def mark_shareable_blocks_as_complete(
Collaborator

Only shareable blocks can be marked as complete and only shareable blocks need this logic (same as the comment above); how can we better "encapsulate" this?

Collaborator Author

I guess this will make more sense when Block is renamed to ShareableBlock; in that case, I think the function name makes sense. The issue is that I renamed one but not the other, imo.

Collaborator

ok!

Comment on lines +783 to +784
self._allow_block_sharing = allow_block_sharing
self._use_prefix_sharing = allow_block_sharing # approximation until the cache is created
Collaborator

Kinda weird to have two attributes that do the same thing.

Collaborator Author

They don't! The block sharing boolean allows hybrid models to do block sharing. The use_prefix_sharing bool is updated after we know if the model is hybrid or not. Hence the comment about the approximation!
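
A tiny sketch of the two flags as explained above. Only the two attribute names mirror the diff; the class and method names are hypothetical:

```python
class SchedulerSketch:
    """Hypothetical sketch of the two flags discussed above, not the real class."""

    def __init__(self, allow_block_sharing: bool = True) -> None:
        self._allow_block_sharing = allow_block_sharing
        # Approximation until the cache is created: we don't yet know whether the
        # model is hybrid, so this is just an upper bound.
        self._use_prefix_sharing = allow_block_sharing

    def on_cache_created(self, group_types: list[str]) -> None:
        # Once the layer groups are known, tighten the flag: prefix sharing also
        # requires a purely full-attention model (a single full-attention group).
        self._use_prefix_sharing = self._allow_block_sharing and group_types == ["full_attention"]

scheduler = SchedulerSketch(allow_block_sharing=True)
scheduler.on_cache_created(["sliding_attention", "full_attention"])  # hybrid model
assert scheduler._allow_block_sharing and not scheduler._use_prefix_sharing
```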

@remi-or (Collaborator, Author) commented Dec 17, 2025

> Sounds good! The perf numbers don't show a diff; they should for models that were hybrid, no?

Not really, sadly: this is more of a memory optimization. It will ensure that parallel decoding won't eat too much memory with hybrid models. The perf table is here more as a sanity check to make sure we did not incur any major overhead with these changes.

@remi-or merged commit 04e78e6 into main Dec 18, 2025
26 checks passed
@remi-or deleted the cb-block-sharing branch December 18, 2025 11:28