Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix/Core] Remove assertion for Flashinfer k_scale and v_scale #9861

Merged
merged 4 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[Bugfix/Core] Remove assertion for Flashinfer k_scale and v_scale
  • Loading branch information
pavanimajety committed Nov 1, 2024
commit 16e816ce5862ea44868d78ee517414884604ce84
21 changes: 14 additions & 7 deletions tests/kernels/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,19 +258,20 @@ def test_reshape_and_cache_flash(
del key_caches
del value_caches

k_scale = key.amax().item() / 256
v_scale = value.amax().item() / 256

# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache)
ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache)
ops.convert_fp8(cloned_value_cache, value_cache, v_scale,
kv_cache_dtype)
else:
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()

# Using default kv_scale
k_scale = v_scale = 1.0

# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
Expand All @@ -281,9 +282,15 @@ def test_reshape_and_cache_flash(

if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(result_key_cache, key_cache)
ops.convert_fp8(result_key_cache,
key_cache,
k_scale,
kv_dtype=kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(result_value_cache, value_cache)
ops.convert_fp8(result_value_cache,
value_cache,
v_scale,
kv_dtype=kv_cache_dtype)

# Run the reference implementation.
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
Expand Down
97 changes: 34 additions & 63 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,6 @@ class FlashInferMetadata(AttentionMetadata):
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculavie decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = 1

use_cuda_graph: bool = True

Expand Down Expand Up @@ -340,7 +335,6 @@ def begin_forward(self):
assert self.paged_kv_last_page_len is not None
assert self.block_table_bound is not None
assert self.seq_lens_tensor is not None
self.query_start_loc = self.query_start_loc[:self.num_prefills + 1]
batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0
# We will use flash attention for profiling to
Expand All @@ -355,13 +349,11 @@ def begin_forward(self):
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.query_start_loc,
self.paged_kv_indptr[:self.num_prefills + 1],
self.paged_kv_indices,
self.paged_kv_last_page_len[:self.num_prefills],
self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len,
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
if self.num_decode_tokens > 0:
else:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
Expand All @@ -378,9 +370,9 @@ def begin_forward(self):
assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
self.decode_wrapper.begin_forward(
self.paged_kv_indptr[self.num_prefills:],
self.paged_kv_indptr,
self.paged_kv_indices,
self.paged_kv_last_page_len[self.num_prefills:],
self.paged_kv_last_page_len,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
Expand All @@ -405,14 +397,21 @@ def asdict_zerocopy(self,

@property
def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
if self.num_prefills == 0:
return None
return self
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
assert self.num_prefills > 0
return self

return None

@property
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
if self.num_decode_tokens == 0:
# Currently chunked prefill is not supported
if self.num_prefills > 0:
assert self.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
return None

return self

def advance_step(self,
Expand Down Expand Up @@ -600,12 +599,11 @@ def build(self, seq_lens: List[int], query_lens: List[int],

max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
decode_query_len = max(query_lens[self.num_prefills:], default=1)

if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
num_decode_tokens = batch_size

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
Expand Down Expand Up @@ -691,7 +689,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
self.runner.kv_cache_dtype, self.runner.model_config.dtype)

return FlashInferMetadata(
decode_query_len=decode_query_len,
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
Expand Down Expand Up @@ -759,8 +756,6 @@ def forward(
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashInfer.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
Expand Down Expand Up @@ -812,6 +807,12 @@ def unified_flash_infer(
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)

if attn_metadata.num_prefill_tokens > 0:
assert attn_metadata.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if kv_cache.numel() > 0:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
Expand All @@ -831,33 +832,14 @@ def unified_flash_infer(
kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)

num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
query = query.contiguous() # Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query = query[num_prefill_tokens:]
query = query[:num_prefill_tokens]

key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]

assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens

prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache.numel() == 0:
prefill_output = flash_attn_varlen_func(
output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
Expand All @@ -873,34 +855,23 @@ def unified_flash_infer(
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
prefill_output = prefill_meta.prefill_wrapper.forward(
query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
if decode_meta := attn_metadata.decode_metadata:
output = prefill_meta.prefill_wrapper.forward(
query,
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=k_scale,
v_scale=v_scale)
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
decode_output = attn_metadata.decode_metadata.decode_wrapper.forward(
decode_query,
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale)

if prefill_output is None and decode_output is not None:
# Decode only batch.
output, num_tokens = decode_output, num_decode_tokens
elif decode_output is None and prefill_output is not None:
# Prefill only batch.
output, num_tokens = prefill_output, num_prefill_tokens
else:
# Chunked prefill batch does not work with speculative decoding in
# FlashInfer backend, so the query length for decode should be 1.
assert prefill_output is not None
assert decode_output is not None
assert decode_meta is not None
assert decode_meta.decode_query_len == 1
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)


Expand Down
7 changes: 5 additions & 2 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,11 @@ def create_weights(
layer.register_parameter("input_scale", scale)

def process_weights_after_loading(self, layer: Module) -> None:
max_w_scale, weight = requantize_with_max_scale(
layer.weight, layer.weight_scale, layer.logical_widths)
weight = layer.weight
max_w_scale = layer.weight_scale.max()
if not (layer.weight_scale == layer.weight_scale[0]).all():
max_w_scale, weight = requantize_with_max_scale(
layer.weight, layer.weight_scale, layer.logical_widths)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(),
Expand Down