[ Misc ] fbgemm checkpoints (vllm-project#6559)
Signed-off-by: Alvant <alvasian@yandex.ru>
robertgshaw2-neuralmagic authored and Alvant committed Oct 26, 2024
1 parent 6979fd0 commit 66aee67
Showing 24 changed files with 234 additions and 47 deletions.
@@ -4,8 +4,8 @@ tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.769
+    value: 0.752
   - name: "exact_match,flexible-extract"
-    value: 0.769
+    value: 0.754
 limit: 1000
 num_fewshot: 5
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.753
+  - name: "exact_match,flexible-extract"
+    value: 0.753
+limit: 1000
+num_fewshot: 5
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
 done
 
 lm_eval --model vllm \
-  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \
+  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \
   --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
   --batch_size $BATCH_SIZE
2 changes: 2 additions & 0 deletions vllm/_custom_ops.py
@@ -315,6 +315,8 @@ def scaled_fp8_quant(
     Args:
         input: The input tensor to be quantized to FP8
         scale: Optional scaling factor for the FP8 quantization
+        scale_ub: Optional upper bound for scaling factor in dynamic
+            per token case
         batch_dim_padding: If specified, pad the first dimension
             of the output to at least this value.
         use_per_token_if_dynamic: Whether to do per_tensor or per_token
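
The new scale_ub argument only matters for dynamic per-token quantization. A minimal sketch of how it might be exercised, assuming a CUDA device, an FP8-capable build of vLLM, and the keyword names documented above (illustrative, not taken from this diff):

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    # Hypothetical clamp value; scale_ub bounds the dynamically computed scales.
    scale_ub = torch.tensor([1200.0], dtype=torch.float32, device="cuda")

    # scale=None requests dynamic quantization; use_per_token_if_dynamic=True
    # computes one scale per token instead of a single per-tensor scale.
    x_fp8, scales = ops.scaled_fp8_quant(x,
                                         scale=None,
                                         scale_ub=scale_ub,
                                         use_per_token_if_dynamic=True)
    print(x_fp8.dtype, scales.shape)
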
3 changes: 2 additions & 1 deletion vllm/attention/layer.py
@@ -34,6 +34,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         if cache_config is not None:
@@ -56,7 +57,7 @@ def __init__(
         self._k_scale = 1.0
         self._v_scale = 1.0
         quant_method = quant_config.get_quant_method(
-            self) if quant_config else None
+            self, prefix=prefix) if quant_config else None
         if quant_method is not None:
             assert isinstance(quant_method, Fp8KVCacheMethod)
             # TODO (mgoin): kv cache dtype should be specified in the FP8
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -251,7 +251,7 @@ def _verify_quantization(self) -> None:
                 f"supported in ROCm.")
         if (self.quantization
                 not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
-                        "compressed_tensors")):
+                        "fbgemm_fp8", "compressed_tensors")):
             logger.warning(
                 "%s quantization is not fully "
                 "optimized yet. The speed can be slower than "
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -182,7 +182,7 @@ def __init__(
             self.quant_method: Optional[QuantizeMethodBase] = (
                 UnquantizedFusedMoEMethod())
         else:
-            self.quant_method = quant_config.get_quant_method(self)
+            self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None
 
         self.quant_method.create_weights(
26 changes: 16 additions & 10 deletions vllm/model_executor/layers/linear.py
@@ -141,6 +141,7 @@ def __init__(
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
 
@@ -155,7 +156,8 @@ def __init__(
             self.quant_method: Optional[
                 QuantizeMethodBase] = UnquantizedLinearMethod()
         else:
-            self.quant_method = quant_config.get_quant_method(self)
+            self.quant_method = quant_config.get_quant_method(self,
+                                                              prefix=prefix)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError
@@ -182,9 +184,13 @@ def __init__(self,
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: Optional[str] = None):
-        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config)
+                 prefix: str = ""):
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix=prefix)
 
         # All the linear layer supports quant method.
         assert self.quant_method is not None
@@ -258,9 +264,9 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  output_sizes: Optional[List[int]] = None,
-                 prefix: Optional[str] = None):
+                 prefix: str = ""):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config)
+                         quant_config, prefix)
 
         self.gather_output = gather_output
 
@@ -370,7 +376,7 @@ def __init__(self,
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: Optional[str] = None):
+                 prefix: str = ""):
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -514,7 +520,7 @@ def __init__(self,
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: Optional[str] = None):
+                 prefix: str = ""):
         self.hidden_size = hidden_size
         self.head_size = head_size
         self.total_num_heads = total_num_heads
@@ -707,9 +713,9 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  reduce_results: bool = True,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: Optional[str] = None):
+                 prefix: str = ""):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config)
+                         quant_config, prefix)
 
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
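
With prefix now threaded through every linear-layer constructor, model code can hand the layer's full state-dict name down to the quantization config. A rough sketch of what that wiring could look like in a model definition (the helper and names below are assumptions for illustration, not part of this commit):

    from vllm.model_executor.layers.linear import RowParallelLinear

    def build_down_proj(hidden_size: int, intermediate_size: int,
                        quant_config, prefix: str) -> RowParallelLinear:
        # e.g. prefix == "model.layers.0.mlp", so this layer identifies itself
        # to the quantization config as "model.layers.0.mlp.down_proj".
        return RowParallelLinear(intermediate_size,
                                 hidden_size,
                                 bias=False,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.down_proj")
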
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -10,6 +10,7 @@
     CompressedTensorsConfig)
 from vllm.model_executor.layers.quantization.deepspeedfp import (
     DeepSpeedFPConfig)
+from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import (
@@ -24,6 +25,7 @@
     "awq": AWQConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "fp8": Fp8Config,
+    "fbgemm_fp8": FBGEMMFp8Config,
     # The order of gptq methods is important for config.py iteration over
     # override_quantization_method(..)
     "marlin": MarlinConfig,
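
Registering "fbgemm_fp8" here makes the new scheme selectable like any other quantization method. A hedged usage example with the checkpoint referenced in the lm-eval config above (the method is normally auto-detected from the checkpoint's config, so passing it explicitly is optional):

    from vllm import LLM

    llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform",
              quantization="fbgemm_fp8")
    out = llm.generate("The GSM8K benchmark measures")
    print(out[0].outputs[0].text)
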
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/aqlm.py
@@ -207,8 +207,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
         return cls(in_group_size, nbits_per_codebook, num_code_books,
                    out_group_size)
 
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["AQLMLinearMethod"]:
         if isinstance(layer, LinearBase):
             return AQLMLinearMethod(self)
         return None
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/awq.py
@@ -63,8 +63,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
         zero_point = cls.get_from_keys(config, ["zero_point"])
         return cls(weight_bits, group_size, zero_point)
 
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["AWQLinearMethod"]:
         if isinstance(layer, LinearBase):
             return AWQLinearMethod(self)
         return None
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/quantization/base_config.py
@@ -97,12 +97,13 @@ def get_from_keys_or(config: Dict[str, Any], keys: List[str],
         return default
 
     @abstractmethod
-    def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional[QuantizeMethodBase]:
         """Get the quantize method to use for the quantized layer.
 
         Args:
             layer: The layer for the quant method.
+            prefix: The full name of the layer in the state dict
         Returns:
             The quantize method. None if the given layer doesn't support quant
             method.
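
The new prefix parameter gives every config the layer's full name in the state dict, which is what makes non-uniform checkpoints (quantize some layers, leave others untouched) expressible. A hypothetical sketch of prefix-based skipping, not code from this commit (the helper name and ignore-list format are assumptions):

    from typing import List, Optional

    import torch

    from vllm.model_executor.layers.linear import (LinearBase,
                                                   UnquantizedLinearMethod)
    from vllm.model_executor.layers.quantization.base_config import (
        QuantizeMethodBase)

    def select_linear_method(layer: torch.nn.Module, prefix: str,
                             ignore_prefixes: List[str],
                             quantized_method_factory) -> Optional[QuantizeMethodBase]:
        """Return an unquantized method for ignored layers, else a quantized one."""
        if not isinstance(layer, LinearBase):
            return None
        # prefix looks like "model.layers.0.self_attn.qkv_proj", so a simple
        # string match is enough to honor an ignore list from the checkpoint.
        if any(prefix.startswith(p) for p in ignore_prefixes):
            return UnquantizedLinearMethod()
        return quantized_method_factory()
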
5 changes: 2 additions & 3 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -60,9 +60,8 @@ def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
         target_modules = cls.get_from_keys(config, ["target_modules"])
         return cls(adapter_name, target_modules)
 
-    def get_quant_method(
-            self,
-            layer: torch.nn.Module) -> Optional["BitsAndBytesLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
         if isinstance(layer, LinearBase):
             return BitsAndBytesLinearMethod(self)
         return None
@@ -44,8 +44,12 @@ def get_min_capability(cls) -> int:
     def get_name(self) -> str:
         return "compressed_tensors"
 
+    # TODO (@robertgshaw2-neuralmagic): do layer skipping though here
+    # rather than though create_weights to match other methods
     def get_quant_method(
-        self, layer: torch.nn.Module
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
     ) -> Optional["CompressedTensorsLinearMethod"]:
         if isinstance(layer, LinearBase):
             return CompressedTensorsLinearMethod(self)
5 changes: 2 additions & 3 deletions vllm/model_executor/layers/quantization/deepspeedfp.py
@@ -69,9 +69,8 @@ def get_config_filenames() -> List[str]:
             "quantize_config.json",
         ]
 
-    def get_quant_method(
-            self,
-            layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["DeepSpeedFPLinearMethod"]:
         if isinstance(layer, LinearBase):
             return DeepSpeedFPLinearMethod(self)
         return None