
[Executorch][llama] Enable quantized sdpa #9945


Merged
6 commits merged on Apr 10, 2025
17 changes: 17 additions & 0 deletions examples/models/llama/TARGETS
@@ -274,3 +274,20 @@ runtime.python_test(
":export_library",
],
)

runtime.python_test(
name = "quantized_sdpa_source_transform_test",
srcs = [
"source_transformation/test_quantized_sdpa.py",
],
preload_deps = [
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
"//executorch/extension/llm/custom_ops:custom_ops_aot_py",
],
deps = [
":custom_kv_cache",
":sdpa",
"//caffe2:torch",
"//executorch/examples/models/llama:llama_transformer",
],
)
5 changes: 5 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -63,6 +63,7 @@
replace_kv_cache_with_custom_kv_cache,
replace_kv_cache_with_quantized_kv_cache,
)

from .source_transformation.quantize import (
get_quant_embedding_transform,
get_quant_weight_transform,
@@ -77,6 +78,7 @@
replace_sdpa_with_coreml_sdpa,
replace_sdpa_with_custom_op,
replace_sdpa_with_flex_sdpa,
replace_sdpa_with_quantized_sdpa,
replace_sdpa_with_simple_sdpa,
)
from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb
@@ -1222,11 +1224,14 @@ def _get_source_transforms( # noqa

if args.use_sdpa_with_kv_cache:
transforms.append(replace_kv_cache_with_custom_kv_cache)
# todo: do this optionally
transforms.append(replace_sdpa_with_custom_op)

if args.quantize_kv_cache:
assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
transforms.append(replace_kv_cache_with_quantized_kv_cache)
# Right now
transforms.append(replace_sdpa_with_quantized_sdpa)

if args.use_kv_cache:
if args.qnn:
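A note on how these transforms compose: --use_sdpa_with_kv_cache installs the custom (float) KV cache and the custom SDPA op, and --quantize_kv_cache (which requires --use_kv_cache) then swaps the cache for the int8 QuantizedKVCache and replaces SDPACustom with QuantizedSDPA so attention reads the quantized cache directly. A minimal sketch of that ordering outside of export_llama_lib.py follows; the absolute import paths and the eager `model` argument are assumptions, and the real flow is driven by the args plumbing above.

import torch

# Sketch only: apply the source transforms in the same order as
# _get_source_transforms above. Import paths assume this repo's layout.
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    replace_kv_cache_with_custom_kv_cache,
    replace_kv_cache_with_quantized_kv_cache,
)
from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
    replace_sdpa_with_quantized_sdpa,
)


def apply_quantized_sdpa_transforms(model: torch.nn.Module) -> torch.nn.Module:
    # --use_sdpa_with_kv_cache: custom KV cache + custom SDPA op.
    model = replace_kv_cache_with_custom_kv_cache(model)
    model = replace_sdpa_with_custom_op(model)
    # --quantize_kv_cache: int8 KV cache, then quantized SDPA on top of it.
    model = replace_kv_cache_with_quantized_kv_cache(model)
    model = replace_sdpa_with_quantized_sdpa(model)
    return model

The order matters: replace_sdpa_with_quantized_sdpa only touches attention modules that already hold both an SDPACustom and a QuantizedKVCache, so it has to run after the cache and custom-op transforms.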
63 changes: 44 additions & 19 deletions examples/models/llama/source_transformation/custom_kv_cache.py
@@ -52,6 +52,8 @@ def __init__(
self.use_custom_update_cache_op = use_custom_update_cache_op
self.quantized_cache_dtype = torch.int8
self.cache_fp_type = torch.float32
self.return_float_values = True
self.max_context_length = max_context_length
cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
scale_shape = (max_batch_size, max_context_length, n_heads, 1)
self.register_buffer(
@@ -61,17 +63,17 @@
"v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
)
self.register_buffer(
"k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
"k_cache_scales", torch.ones(scale_shape, dtype=torch.float32)
)
self.register_buffer(
"v_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
"v_cache_scales", torch.ones(scale_shape, dtype=torch.float32)
)
if cache_type == QuantizedCacheType.AffineAsymmetric:
self.register_buffer(
"k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
"k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
)
self.register_buffer(
"v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
"v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8)
)

def _quantize(self, value):
@@ -91,20 +93,15 @@ def _quantize(self, value):
)
return quantized_value, scales, zero_points

def update(self, input_pos, k_val, v_val):
"""
k_val, v_val: [B, H, S, D]
return: [B, H, S, D]
However the storage is [B, S, H, D] so we incur transpose in, transpose out
This shall be removed by subsequent post-export graph pass
"""
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)
# quantize current k_val and store it in the cache
def _quantize_and_update(self, input_pos, k_val, v_val):
quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)

quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

k_scales = k_scales.to(torch.float32)
k_zero_points = k_zero_points.to(self.quantized_cache_dtype)
v_scales = v_scales.to(torch.float32)
v_zero_points = v_zero_points.to(self.quantized_cache_dtype)

if self.use_custom_update_cache_op:
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
@@ -125,25 +122,30 @@ def update(self, input_pos, k_val, v_val):
self.v_cache_scales[:, input_pos] = v_scales
self.v_cache_zero_points[:, input_pos] = v_zero_points

def _update_and_return_float_values(self, input_pos, k_val, v_val):
self._quantize_and_update(input_pos, k_val, v_val)

k_out = torch.ops.quantized_decomposed.dequantize_per_token(
self.k_cache,
self.k_cache_scales,
self.k_cache_zero_points,
self.k_cache_scales.to(torch.float64),
self.k_cache_zero_points.to(torch.int64),
torch.iinfo(self.quantized_cache_dtype).min,
torch.iinfo(self.quantized_cache_dtype).max,
self.quantized_cache_dtype,
self.cache_fp_type,
)
v_out = torch.ops.quantized_decomposed.dequantize_per_token(
self.v_cache,
self.v_cache_scales,
self.v_cache_zero_points,
self.v_cache_scales.to(torch.float64),
self.v_cache_zero_points.to(torch.int64),
torch.iinfo(self.quantized_cache_dtype).min,
torch.iinfo(self.quantized_cache_dtype).max,
self.quantized_cache_dtype,
self.cache_fp_type,
)

# When returning float values we just use the original float values at the
# current positions instead of their dequantized counterparts.
start_pos = input_pos[0].item()
if self.use_custom_update_cache_op:
_ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
@@ -152,6 +154,29 @@ def update(self, input_pos, k_val, v_val):
k_out[:, input_pos] = k_val
v_out[:, input_pos] = v_val

return k_out, v_out

def _update_and_return_quantized_values(self, input_pos, k_val, v_val):
self._quantize_and_update(input_pos, k_val, v_val)

return self.k_cache, self.v_cache

def update(self, input_pos, k_val, v_val):
"""
k_val, v_val: [B, H, S, D]
return: [B, H, S, D]
However the storage is [B, S, H, D] so we incur transpose in, transpose out
This shall be removed by subsequent post-export graph pass
"""
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)

if self.return_float_values:
k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val)
else:
k_out, v_out = self._update_and_return_quantized_values(
input_pos, k_val, v_val
)
return k_out.transpose(1, 2), v_out.transpose(1, 2)

@classmethod
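For intuition about what _quantize_and_update stores, here is a small, self-contained sketch of per-token asymmetric int8 quantization in plain PyTorch. It mirrors the math behind the quantized_decomposed ops used above but is not that code; the helper names are made up for illustration.

import torch


def quantize_per_token_asym(x: torch.Tensor, eps: float = 1e-9):
    # One (scale, zero_point) pair per token, i.e. per slice over the last dim,
    # matching the [B, S, H, 1] scale/zero-point buffers used by the cache.
    qmin, qmax = -128, 127
    # Include 0 in the range so the zero point always fits in int8.
    x_min = torch.minimum(x.amin(dim=-1, keepdim=True), torch.zeros_like(x[..., :1]))
    x_max = torch.maximum(x.amax(dim=-1, keepdim=True), torch.zeros_like(x[..., :1]))
    scale = ((x_max - x_min) / (qmax - qmin)).clamp(min=eps)
    zero_point = (qmin - torch.round(x_min / scale)).clamp(qmin, qmax)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
    return q, scale, zero_point.to(torch.int8)


def dequantize_per_token(q, scale, zero_point):
    return (q.to(torch.float32) - zero_point.to(torch.float32)) * scale


x = torch.randn(1, 4, 8, 16)  # [B, S, H, D], like a transposed k_val
q, s, zp = quantize_per_token_asym(x)
x_hat = dequantize_per_token(q, s, zp)
# Reconstruction error stays within a couple of quantization steps.
assert (x - x_hat).abs().max() <= 2 * s.max()

Note that the diff above also narrows the cache's scale and zero-point buffers to float32 and int8 (from float64/int64) and casts to those dtypes in _quantize_and_update, which is what the quantized SDPA path consumes.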
122 changes: 121 additions & 1 deletion examples/models/llama/source_transformation/sdpa.py
@@ -13,7 +13,9 @@

import torch

from executorch.examples.models.llama.attention import KVCache, SDPA
from executorch.examples.models.llama.attention import Attention, KVCache, SDPA

from .custom_kv_cache import QuantizedKVCache


class SDPACustom(torch.nn.Module):
@@ -76,6 +78,124 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
return module


class QuantizedSDPA(torch.nn.Module):
"""
A quantized version of the SDPA (Scaled Dot Product Attention) module.

This module implements attention computation using quantized key-value pairs
to reduce memory footprint and potentially improve performance. It works with
a QuantizedKVCache to store and retrieve quantized key-value tensors.

The quantization process converts floating point tensors to int8, which requires
maintaining scale and zero point values for proper dequantization during computation.

Args:
dim (int): The dimension of the model
kv_cache (QuantizedKVCache): The cache for storing quantized key-value pairs
Note that QuantizedSDPA needs to own the kv_cache to access its scales and zero
points; since the SDPA forward signature only accepts q, k, and v, the kv_cache
is passed in at construction time.
"""

def __init__(self, dim: int, kv_cache: QuantizedKVCache):
super().__init__()
self.dim = dim
self.quantized_dtype = torch.int8
self.float_dtype = torch.float32
self.kv_cache = kv_cache

def forward(
self,
input_pos: torch.Tensor,
q: torch.Tensor,
k_quantized: torch.Tensor,
v_quantized: torch.Tensor,
bsz,
seqlen,
mask,
):
q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim)
k_quantized = k_quantized.transpose(1, 2)
v_quantized = v_quantized.transpose(1, 2)

q_scale, q_zero_point = (
torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
q, self.quantized_dtype
)
)
q_quantized = torch.ops.quantized_decomposed.quantize_per_token(
q,
q_scale,
q_zero_point,
torch.iinfo(self.quantized_dtype).min,
torch.iinfo(self.quantized_dtype).max,
self.quantized_dtype,
)
q_zero_point_int8 = q_zero_point.to(dtype=torch.int8)
q_scale_fp32 = q_scale.to(dtype=torch.float32)

k_zero_point_int8 = self.kv_cache.k_cache_zero_points
k_scale_fp32 = self.kv_cache.k_cache_scales
v_zero_point_int8 = self.kv_cache.v_cache_zero_points
v_scale_fp32 = self.kv_cache.v_cache_scales

start_pos = input_pos[0].item()
output = torch.ops.llama.custom_quantized_sdpa(
q_quantized,
k_quantized,
v_quantized,
start_pos,
None,
0,
True,
None,
q_zero_point_int8,
q_scale_fp32,
k_zero_point_int8,
k_scale_fp32,
v_zero_point_int8,
v_scale_fp32,
)

return output.view(bsz, seqlen, self.dim)


def _update_attention_module_with_quantized_sdpa(
module: torch.nn.Module, kv_cache: QuantizedKVCache
):
sdpa = getattr(module, "SDPA", None)
assert sdpa is not None
# pyre-ignore
setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache)) # noqa: B010


def _replace_sdpa_with_quantized_sdpa(module: torch.nn.Module):
for _, child in module.named_children():
if isinstance(child, Attention):
kv_cache = getattr(child, "kv_cache", None)
if kv_cache is None:
continue
if not isinstance(kv_cache, QuantizedKVCache):
continue
# Only when kv_cache is QuantizedKVCache, we replace SDPA with QuantizedSDPA
sdpa = getattr(child, "SDPA", None)
if sdpa is None:
continue
if not isinstance(sdpa, SDPACustom):
continue
kv_cache.return_float_values = False
_update_attention_module_with_quantized_sdpa(child, kv_cache)
else:
_replace_sdpa_with_quantized_sdpa(child)


def replace_sdpa_with_quantized_sdpa(module: torch.nn.Module) -> torch.nn.Module:
from executorch.extension.llm.custom_ops import custom_ops # noqa

_replace_sdpa_with_quantized_sdpa(module)
return module


class SDPASimple(torch.nn.Module):
def __init__(
self,
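For reference, the numerics that the fused torch.ops.llama.custom_quantized_sdpa call above is expected to approximate can be written in a few lines of plain PyTorch: dequantize the per-token int8 q/k/v with their scales and zero points, then run ordinary scaled dot-product attention. This is a conceptual sketch under assumed [B, S, H, 1] scale/zero-point shapes and identical q/kv head counts, not the kernel itself; masking, GQA, and start_pos handling in the real op are more involved.

import math

import torch


def reference_quantized_sdpa(
    q_q, q_scale, q_zp,
    k_q, k_scale, k_zp,
    v_q, v_scale, v_zp,
    mask=None,
):
    # Dequantize: x_fp = (x_int8 - zero_point) * scale. Scales/zero points are
    # per token with shape [B, S, H, 1], so they broadcast over head_dim.
    q = (q_q.to(torch.float32) - q_zp.to(torch.float32)) * q_scale
    k = (k_q.to(torch.float32) - k_zp.to(torch.float32)) * k_scale
    v = (v_q.to(torch.float32) - v_zp.to(torch.float32)) * v_scale
    # [B, S, H, D] -> [B, H, S, D] for the attention matmuls.
    q, k, v = [t.transpose(1, 2) for t in (q, k, v)]
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores + mask
    attn = torch.softmax(scores, dim=-1)
    # Back to [B, S, H, D]; the caller can then .view(bsz, seqlen, dim).
    return (attn @ v).transpose(1, 2)

One motivation for the fused kernel is to avoid materializing these dequantized copies of the whole cache: QuantizedSDPA hands the int8 cache plus its scales and zero points straight to the op, which is why QuantizedKVCache.return_float_values is flipped to False during the replacement.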