-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[V1] Allocate kv_cache with stride order for V1 #18775
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2033,9 +2033,29 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: | |
num_blocks, kv_cache_spec.block_size, | ||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) | ||
dtype = kv_cache_spec.dtype | ||
kv_caches[layer_name] = torch.zeros(kv_cache_shape, | ||
dtype=dtype, | ||
device=self.device) | ||
try: | ||
kv_cache_stride_order = self.attn_backends[ | ||
i].get_kv_cache_stride_order() | ||
assert len(kv_cache_stride_order) == len( | ||
kv_cache_shape) | ||
except (AttributeError, NotImplementedError): | ||
kv_cache_stride_order = tuple( | ||
range(len(kv_cache_shape))) | ||
Comment on lines
+2036
to
+2043
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The try/catch seems a little sketchy. but I don't see a better way. I thought about handling it in the abstract There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yep same thought |
||
# The allocation respects the backend-defined stride order | ||
# to ensure the semantic remains consistent for each | ||
# backend. We first obtain the generic kv cache shape and | ||
# then permute it according to the stride order which could | ||
# result in a non-contiguous tensor. | ||
kv_cache_shape = tuple(kv_cache_shape[i] | ||
for i in kv_cache_stride_order) | ||
# Maintain original KV shape view. | ||
inv_order = [ | ||
kv_cache_stride_order.index(i) | ||
for i in range(len(kv_cache_stride_order)) | ||
] | ||
kv_caches[layer_name] = torch.zeros( | ||
kv_cache_shape, dtype=dtype, | ||
device=self.device).permute(*inv_order) | ||
else: | ||
# TODO: add new branches when introducing more types of | ||
# KV cache specs. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be a little better to check the exact stride order -- although this is ambiguous if one of the dims is length 1