Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ Kernel ] Enable fp8-marlin for fbgemm-fp8 models #6606

Merged
merged 47 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
2426e29
stash
robertgshaw2-neuralmagic Jul 18, 2024
7665b7b
format
robertgshaw2-neuralmagic Jul 18, 2024
ef27613
tweak arg name
robertgshaw2-neuralmagic Jul 18, 2024
2f96157
fix test
robertgshaw2-neuralmagic Jul 18, 2024
e748554
format
robertgshaw2-neuralmagic Jul 18, 2024
3ef571b
working e2e with our cutlass kernels
robertgshaw2-neuralmagic Jul 19, 2024
ad83666
added fp8 gemm
robertgshaw2-neuralmagic Jul 19, 2024
eb7d48c
remove
robertgshaw2-neuralmagic Jul 19, 2024
90bd839
format
robertgshaw2-neuralmagic Jul 19, 2024
15cc823
Merge branch 'main' into turn-on-fp8-dyn-per-token
robertgshaw2-neuralmagic Jul 19, 2024
d064dd7
stash
robertgshaw2-neuralmagic Jul 19, 2024
6aa37e5
dynamic per token
robertgshaw2-neuralmagic Jul 19, 2024
c9d819a
format
robertgshaw2-neuralmagic Jul 19, 2024
08cbaf7
reenable cutlass
robertgshaw2-neuralmagic Jul 19, 2024
f4cdda1
cleanup comment
robertgshaw2-neuralmagic Jul 19, 2024
2971f4d
format
robertgshaw2-neuralmagic Jul 19, 2024
b601033
added dynamic per token test case
robertgshaw2-neuralmagic Jul 19, 2024
5d8edf9
Merge branch 'turn-on-fp8-dyn-per-token' into fbgemm-checkpoints
robertgshaw2-neuralmagic Jul 19, 2024
8b5d638
added use per token
robertgshaw2-neuralmagic Jul 19, 2024
006ccf0
format
Jul 19, 2024
1884acf
format
Jul 19, 2024
fe14072
Make optional ubs none
Jul 19, 2024
254dcff
format
Jul 19, 2024
919d866
Merge branch 'fp8-dpt-fpgemm' into fbgemm-checkpoints
robertgshaw2-neuralmagic Jul 19, 2024
227a277
hook up end to end with varun's ub quant kernel
robertgshaw2-neuralmagic Jul 19, 2024
951834a
formatted
robertgshaw2-neuralmagic Jul 19, 2024
9aa66d3
updated for nonuniform
robertgshaw2-neuralmagic Jul 19, 2024
458a410
formatting after passing prefix around
robertgshaw2-neuralmagic Jul 19, 2024
278f6d6
Merge branch 'main' into fbgemm-checkpoints
robertgshaw2-neuralmagic Jul 19, 2024
3e4aaad
fixed bad merge
robertgshaw2-neuralmagic Jul 19, 2024
de2a764
updated message
robertgshaw2-neuralmagic Jul 20, 2024
268fe94
Merge branch 'main' into fbgemm-checkpoints
robertgshaw2-neuralmagic Jul 20, 2024
c88fe34
merged varun's pr
robertgshaw2-neuralmagic Jul 20, 2024
bb02a3f
fixed
robertgshaw2-neuralmagic Jul 20, 2024
1c8f71c
cleanup pr
robertgshaw2-neuralmagic Jul 20, 2024
6970e50
Update config.py
robertgshaw2-neuralmagic Jul 20, 2024
94617f0
fixed config
robertgshaw2-neuralmagic Jul 20, 2024
f9d569c
updated for new ckpt format, turned on ada lovelace, and added test case
robertgshaw2-neuralmagic Jul 20, 2024
ae45615
format
robertgshaw2-neuralmagic Jul 20, 2024
e2a1eda
add marlin support to fbgemm
robertgshaw2-neuralmagic Jul 20, 2024
a4abc78
fix configs
robertgshaw2-neuralmagic Jul 20, 2024
5008ecb
fix configs
robertgshaw2-neuralmagic Jul 20, 2024
615a2ed
added marlin nonuniform test
robertgshaw2-neuralmagic Jul 20, 2024
7ea9025
Merge branch 'main' into fbgemm-fp8-marlin
robertgshaw2-neuralmagic Jul 20, 2024
da37598
fixed
robertgshaw2-neuralmagic Jul 20, 2024
a14116c
Merge branch 'main' into fbgemm-fp8-marlin
robertgshaw2-neuralmagic Jul 20, 2024
183bfe7
use marlin remove:
robertgshaw2-neuralmagic Jul 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
updated for nonuniform
  • Loading branch information
robertgshaw2-neuralmagic committed Jul 19, 2024
commit 9aa66d3565daceafc9fad7407f5339c5a27a5ce2
3 changes: 2 additions & 1 deletion vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
prefix: str = "",
) -> None:
super().__init__()
if cache_config is not None:
Expand All @@ -56,7 +57,7 @@ def __init__(
self._k_scale = 1.0
self._v_scale = 1.0
quant_method = quant_config.get_quant_method(
self) if quant_config else None
self, prefix=prefix) if quant_config else None
if quant_method is not None:
assert isinstance(quant_method, Fp8KVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
Expand Down
9 changes: 5 additions & 4 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def __init__(
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()

Expand All @@ -155,7 +156,7 @@ def __init__(
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self)
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)

def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
Expand Down Expand Up @@ -184,7 +185,7 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
quant_config, prefix=prefix)

# All the linear layer supports quant method.
assert self.quant_method is not None
Expand Down Expand Up @@ -260,7 +261,7 @@ def __init__(self,
output_sizes: Optional[List[int]] = None,
prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
quant_config, prefix)

self.gather_output = gather_output

Expand Down Expand Up @@ -709,7 +710,7 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = None):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
quant_config, prefix)

self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,12 @@ def get_from_keys_or(config: Dict[str, Any], keys: List[str],

@abstractmethod
def get_quant_method(
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
self, layer: torch.nn.Module, prefix: str) -> Optional[QuantizeMethodBase]:
"""Get the quantize method to use for the quantized layer.

Args:
layer: The layer for the quant method.
prefix: The full name of the layer in the state dict
Returns:
The quantize method. None if the given layer doesn't support quant
method.
Expand Down
53 changes: 48 additions & 5 deletions vllm/model_executor/layers/quantization/fbgemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Expand All @@ -15,16 +15,28 @@
logger = init_logger(__name__)


# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}


class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8."""

def __init__(self, ignore_list: List[str]):
self.ignore_list = ignore_list

@classmethod
def get_name(cls) -> str:
return "fbgemm_fp8"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16]
return [torch.bfloat16, torch.float16]

@classmethod
def get_min_capability(cls) -> int:
Expand All @@ -36,11 +48,42 @@ def get_config_filenames(cls) -> List[str]:

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
return cls()

ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
return cls(ignore_list=ignore_list)

def _is_layer_skipped(self, prefix: str) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name = prefix.split(".")[-1]
if proj_name in _FUSED_LAYER_NAME_MAPPING:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name) for
shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name]
]

is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = shard_prefix in self.ignore_list

if is_skipped is None:
is_skipped = is_shard_skipped
elif is_shard_skipped != is_skipped:
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision."
)
else:
is_skipped = prefix in self.ignore_list

assert is_skipped is not None
return is_skipped

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if self._is_layer_skipped(prefix):
return UnquantizedLinearMethod()
return FBGEMMFp8LinearMethod(self)
return None

Expand Down
11 changes: 7 additions & 4 deletions vllm/model_executor/layers/vocab_parallel_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class VocabParallelEmbedding(torch.nn.Module):
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
quant_config: quant config for the layer
prefix: full name of the layer in the state dict
""" # noqa: E501

def __init__(self,
Expand All @@ -169,7 +170,8 @@ def __init__(self,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()

# Keep the input dimensions.
Expand All @@ -195,7 +197,7 @@ def __init__(self,

linear_method = None
if quant_config is not None:
linear_method = quant_config.get_quant_method(self)
linear_method = quant_config.get_quant_method(self, prefix=prefix)
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method: QuantizeMethodBase = linear_method
Expand Down Expand Up @@ -382,9 +384,10 @@ def __init__(self,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size, quant_config)
org_num_embeddings, padding_size, quant_config, prefix)
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
Expand Down