41 changes: 41 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -69,6 +69,47 @@ def _is_col_major(x: torch.Tensor) -> bool:
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m


class PureFp8Config(QuantizationConfig):
    """Custom FP8 config for pure FP8 weights without scales."""

    def __init__(self) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = True
        self.activation_scheme = "none"  # No activation quantization
        self.ignored_layers = []
        self.weight_block_size = None
        self.use_tma_kernel = True  # Flag to use the TMA kernel

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "pure_fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "PureFp8Config":
        """Create a PureFp8Config from a config dict; no additional keys are needed."""
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return PureFp8LinearMethod(self)
        return None


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""
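`PureFp8Config.get_quant_method` returns a `PureFp8LinearMethod` that is not defined in this diff. A minimal sketch of what such a method might look like, assuming it only registers a scale-free `float8_e4m3fn` weight and defers the matmul to the TMA path added in `layers/utils.py` (the class body, the imports, and the use of `set_weight_attrs` are assumptions, not the PR's actual implementation):

import torch
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.utils import set_weight_attrs


class PureFp8LinearMethod(LinearMethodBase):
    """Sketch: pure FP8 weights, no scales, no activation quantization."""

    def __init__(self, quant_config: "PureFp8Config"):
        self.quant_config = quant_config

    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, input_size, output_size,
                       params_dtype, **extra_weight_attrs):
        # Keep the checkpoint weight directly in float8_e4m3fn.
        weight = torch.nn.Parameter(
            torch.empty(sum(output_partition_sizes),
                        input_size_per_partition,
                        dtype=torch.float8_e4m3fn),
            requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self, layer, x, bias=None):
        # Route through the FP8 TMA GEMM added in layers/utils.py.
        from vllm.model_executor.layers.utils import fp8_tma_linear
        return fp8_tma_linear(x, layer.weight, bias, layer=layer)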
201 changes: 196 additions & 5 deletions vllm/model_executor/layers/utils.py
@@ -9,6 +9,22 @@
from vllm import envs
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils import direct_register_custom_op
# Import the TMA persistent GEMM kernel.
import os
import sys

# TODO: move tma_persistent_gemm into vllm proper.
# Add the customized kernel directory to the Python path (Docker container path).
customized_kernel_dir = "/.venv/lib/python3.12/site-packages/vllm/customized_kernel"
if customized_kernel_dir not in sys.path:
    sys.path.insert(0, customized_kernel_dir)

# Also try the local directory as a fallback (for non-Docker execution).
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

from tma_persistent_gemm import matmul_tma_persistent
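# Assumed contract, inferred from the call sites below (the kernel itself is not part
# of this diff): matmul_tma_persistent(a, b, bias=None) takes a as (M, K) activations,
# b as (N, K) float8_e4m3fn weights, and returns an (M, N) bfloat16 tensor.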


def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
@@ -57,10 +73,10 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
"""
Applies penalties in place to the logits tensor
logits : The input logits tensor of shape [num_seqs, vocab_size]
prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
are padded to the maximum prompt length within the batch using
`vocab_size` as the padding value. The value `vocab_size` is used
for padding because it does not correspond to any valid token ID
prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
are padded to the maximum prompt length within the batch using
`vocab_size` as the padding value. The value `vocab_size` is used
for padding because it does not correspond to any valid token ID
in the vocabulary.
output_tokens_tensor: The output tokens tensor.
presence_penalties: The presence penalties of shape (num_seqs, )
@@ -187,11 +203,186 @@ def cpu_unquantized_gemm(layer: torch.nn.Module,
                         bias: Optional[torch.Tensor] = None):
    return layer.cpu_linear(x, weight, bias)

def fp8_tma_linear(x: torch.Tensor,
                   weight: torch.Tensor,
                   bias: Optional[torch.Tensor] = None,
                   layer=None):
    """Optimized FP8 TMA linear function with CUDA graph capture support.

    Args:
        x: Input tensor. It is not converted to FP8 here; the conversion
            happens inside the kernel.
        weight: Weight tensor in float8_e4m3fn format (already transposed).
        bias: Optional bias tensor.
        layer: Layer instance used to access the vLLM config and persistent buffers.

    Returns:
        Output tensor in BF16 format.
    """
    try:
        # Do not convert the input to FP8 here; the kernel handles the conversion.
        # if x.dtype != torch.float8_e4m3fn:
        #     x = x.to(torch.float8_e4m3fn)

        # The weight should already be FP8, prepared by PureFp8LinearMethod.
        assert weight.dtype == torch.float8_e4m3fn, \
            f"Expected FP8 weight, got {weight.dtype}"

        # Handle arbitrary input shapes like torch.nn.functional.linear.
        input_shape = x.shape
        x_2d = x.view(-1, x.size(-1))  # Flatten to 2D for the matmul.
        num_tokens = x_2d.shape[0]

        if layer is not None:
            # Pad the token count up to the next CUDA graph capture bucket.
            M_cap = num_tokens
            for size in layer.cudagraph_batch_sizes:
                if num_tokens <= size:
                    M_cap = size
                    break

            # CUDA graph capture and replay for the TMA persistent kernel.
            if M_cap <= layer.cudagraph_batch_sizes[-1]:
                graph_key = M_cap  # The weight is fixed, so the batch size is a sufficient key.

                if graph_key not in layer._tma_graphs:
                    # The TMA kernel expects the weight laid out as (N, K).
                    if weight.shape[1] != x_2d.shape[1]:
                        weight_for_tma = weight.T
                    else:
                        weight_for_tma = weight

                    # Create static input/output buffers for this batch size and weight.
                    padded_input_shape = (M_cap, x_2d.shape[1])
                    output_shape = (M_cap, weight_for_tma.shape[0])

                    layer._tma_inputs[graph_key] = torch.zeros(
                        padded_input_shape, dtype=x.dtype, device=x.device)
                    layer._tma_outputs[graph_key] = torch.zeros(
                        output_shape, dtype=torch.bfloat16, device=x.device)
                    layer._tma_weights[graph_key] = weight_for_tma

                    # Use a dedicated graph memory pool and side stream for stable
                    # capture, falling back gracefully if the pool API is unavailable.
                    try:
                        memory_pool = torch.cuda.graph_pool_handle()
                    except Exception:
                        memory_pool = None

                    stream = torch.cuda.Stream()

                    # Capture the graph. The tensor returned during capture becomes
                    # the static output buffer that later replays write into.
                    layer._tma_graphs[graph_key] = torch.cuda.CUDAGraph()
                    if memory_pool is not None:
                        with torch.cuda.graph(layer._tma_graphs[graph_key],
                                              pool=memory_pool, stream=stream):
                            layer._tma_outputs[graph_key] = matmul_tma_persistent(
                                layer._tma_inputs[graph_key],
                                layer._tma_weights[graph_key],
                                bias=bias)
                    else:
                        # Fallback when no graph memory pool can be obtained.
                        with torch.cuda.graph(layer._tma_graphs[graph_key],
                                              stream=stream):
                            layer._tma_outputs[graph_key] = matmul_tma_persistent(
                                layer._tma_inputs[graph_key],
                                layer._tma_weights[graph_key],
                                bias=bias)

                # Copy the input data into the static buffer.
                if M_cap > num_tokens:
                    layer._tma_inputs[graph_key][:num_tokens] = x_2d
                    # The padded tail is intentionally not re-zeroed to reduce overhead.
                else:
                    layer._tma_inputs[graph_key].copy_(x_2d)

                # Replay the graph.
                layer._tma_graphs[graph_key].replay()

                # Slice the output back to the actual token count if padded.
                if M_cap > num_tokens:
                    output_2d = layer._tma_outputs[graph_key][:num_tokens]
                else:
                    output_2d = layer._tma_outputs[graph_key]
            else:
                # Fallback for large batch sizes: run the TMA kernel without a graph.
                if weight.shape[1] != x_2d.shape[1]:
                    weight_for_tma = weight.T
                else:
                    weight_for_tma = weight

                output_2d = matmul_tma_persistent(x_2d, weight_for_tma, bias=bias)

        else:
            # Fallback for calls without layer context.
            # The TMA kernel expects a=(M, K), b=(N, K) with a.shape[1] == b.shape[1].
            if weight.shape[1] != x_2d.shape[1]:
                weight_for_tma = weight.T
            else:
                weight_for_tma = weight

            output_2d = matmul_tma_persistent(x_2d, weight_for_tma, bias=bias)

            # Previous cuBLAS FP8 _scaled_mm path without layer context, kept for reference:
            # scale_a = torch.tensor(1.0, device=x_2d.device, dtype=torch.float32)
            # scale_b = torch.tensor(1.0, device=weight_t.device, dtype=torch.float32)
            # output_2d = torch.ops.aten._scaled_mm(
            #     x_2d, weight_t,
            #     scale_a=scale_a,
            #     scale_b=scale_b,
            #     bias=bias,
            #     out_dtype=torch.bfloat16,
            # )

        # Reshape back to the expected output shape.
        # weight_for_tma is (N, K), so the output feature count is its first dimension.
        output_features = (weight_for_tma.shape[0]
                           if 'weight_for_tma' in locals() else weight.shape[0])
        output_shape = input_shape[:-1] + (output_features,)
        output = output_2d.view(output_shape)

        return output
    except ImportError as e:
        # Fall back to a regular linear if the TMA kernel is not available.
        print(f"Warning: TMA kernel not available ({e}), falling back to regular linear")
        return torch.nn.functional.linear(x, weight, bias)
    except Exception as e:
        # Fall back for any other error.
        print(f"Warning: TMA kernel failed ({e}), falling back to regular linear")
        return torch.nn.functional.linear(x, weight, bias)
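# fp8_tma_linear above assumes the layer already carries `cudagraph_batch_sizes` and the
# `_tma_*` caches. An illustrative sketch of how that state could be attached (for example
# from PureFp8LinearMethod.process_weights_after_loading); the helper name and the bucket
# list are assumptions, not part of this diff:
def _init_tma_graph_state(layer: torch.nn.Module,
                          batch_size_buckets: Optional[list[int]] = None) -> None:
    # Capture buckets: num_tokens is padded up to the nearest entry at call time.
    layer.cudagraph_batch_sizes = sorted(batch_size_buckets
                                         or [1, 2, 4, 8, 16, 32, 64, 128, 256])
    layer._tma_graphs = {}   # batch size -> torch.cuda.CUDAGraph
    layer._tma_inputs = {}   # batch size -> padded static input buffer
    layer._tma_outputs = {}  # batch size -> static bfloat16 output buffer
    layer._tma_weights = {}  # batch size -> weight view laid out as (N, K)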


def unquantized_gemm_with_fp8_support(layer: torch.nn.Module,
                                      x: torch.Tensor,
                                      weight: torch.Tensor,
                                      bias: Optional[torch.Tensor] = None):
    """Unquantized GEMM that routes FP8 weights to the TMA kernel."""
    if weight.dtype == torch.float8_e4m3fn:
        # FP8 weight: use the TMA kernel path.
        return fp8_tma_linear(x, weight, bias, layer=layer)
    else:
        # Regular linear for non-FP8 weights.
        return torch.nn.functional.linear(x, weight, bias)

def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
    if current_platform.is_rocm():
        return rocm_unquantized_gemm
    elif current_platform.is_cpu():
        return cpu_unquantized_gemm
    else:
        # Previously: return default_unquantized_gemm
        return unquantized_gemm_with_fp8_support
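For reference, the dispatcher is resolved once and the returned callable is used like the other unquantized GEMM backends above; in this sketch `layer`, `x`, and `bias` are assumed to be in scope:

gemm = dispatch_unquantized_gemm()
# Routes to the FP8 TMA path when layer.weight is float8_e4m3fn, otherwise F.linear.
out = gemm(layer, x, layer.weight, bias)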
13 changes: 9 additions & 4 deletions vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -204,6 +204,7 @@ class VocabParallelEmbedding(CustomOp):
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer
        prefix: full name of the layer in the state dict
        tp_size: tensor parallel size. If set, it overrides the world size of the TP group.
    """  # noqa: E501

    def __init__(self,
@@ -213,12 +214,14 @@ def __init__(self,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 tp_size: Optional[int] = None):
        super().__init__()

        # Keep the input dimensions.
        tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if tp_size is None else tp_size)
        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
@@ -447,6 +450,7 @@ class ParallelLMHead(VocabParallelEmbedding):
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
        tp_size: tensor parallel size. If set, it overrides the world size of the TP group.
    """

    def __init__(self,
@@ -457,10 +461,11 @@ def __init__(self,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 tp_size: Optional[int] = None):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         org_num_embeddings, padding_size, quant_config,
                         prefix, tp_size)
        self.quant_config = quant_config
        if bias:
            self.bias = Parameter(
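A minimal usage sketch of the new override; the call site, `config`, and `draft_tp` are assumed names, not taken from this PR:

lm_head = ParallelLMHead(config.vocab_size,
                         config.hidden_size,
                         org_num_embeddings=config.vocab_size,
                         prefix="lm_head",
                         tp_size=draft_tp)  # use the given value instead of the global TP world size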
2 changes: 2 additions & 0 deletions vllm/model_executor/models/llama_eagle.py
@@ -163,6 +163,8 @@ def transform(inputs):
            name, loaded_weight = inputs
            if "lm_head" not in name:
                name = "model." + name
            if "lm_head" in name:
                return
            return name, loaded_weight

        loader = AutoWeightsLoader(
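Assuming `AutoWeightsLoader` skips entries for which `transform` returns None, the updated transform behaves roughly as follows (`w` is a placeholder tensor):

transform(("lm_head.weight", w))                    # -> None: the lm_head entry is dropped
transform(("layers.0.self_attn.q_proj.weight", w))  # -> ("model.layers.0.self_attn.q_proj.weight", w)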