Remove ScaledActivation for AWQ (vllm-project#10057)
Signed-off-by: mgoin <michael@neuralmagic.com>
mgoin authored Nov 6, 2024
1 parent 406d4cc commit 399c798
Showing 34 changed files with 19 additions and 124 deletions.
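The change collapses the quantization-aware activation lookup into a plain registry lookup by name. A hedged before/after sketch of the call-site pattern (the old and new signatures are taken from the activation.py diff below; quant_config and intermediate_size are the arguments being dropped):

# Runnable against a vLLM checkout that includes this commit.
from vllm.model_executor.layers.activation import get_act_fn

# Before this commit, quantization state was threaded into the lookup:
#     act = get_act_fn("gelu", quant_config, intermediate_size)
# After it, the lookup is by name only:
act = get_act_fn("gelu")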
37 changes: 4 additions & 33 deletions vllm/model_executor/layers/activation.py
@@ -9,7 +9,6 @@
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import LazyDict
@@ -277,28 +276,14 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
 })


-def get_act_fn(
-    act_fn_name: str,
-    quant_config: Optional[QuantizationConfig] = None,
-    intermediate_size: Optional[int] = None,
-    input_is_parallel: bool = True,
-    params_dtype: Optional[torch.dtype] = None,
-) -> nn.Module:
+def get_act_fn(act_fn_name: str) -> nn.Module:
     """Get an activation function by name."""
     act_fn_name = act_fn_name.lower()
     if act_fn_name not in _ACTIVATION_REGISTRY:
         raise ValueError(
             f"Activation function {act_fn_name!r} is not supported.")

-    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
-    if (quant_config is not None
-            and act_fn_name in quant_config.get_scaled_act_names()):
-        if intermediate_size is None:
-            raise ValueError("intermediate_size must be specified for scaled "
-                             "activation functions.")
-        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
-                                params_dtype)
-    return act_fn
+    return _ACTIVATION_REGISTRY[act_fn_name]


 _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
@@ -307,25 +292,11 @@ def get_act_fn
 })


-def get_act_and_mul_fn(
-    act_fn_name: str,
-    quant_config: Optional[QuantizationConfig] = None,
-    intermediate_size: Optional[int] = None,
-    input_is_parallel: bool = True,
-    params_dtype: Optional[torch.dtype] = None,
-) -> nn.Module:
+def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
     """Get an activation-and-mul (i.e. SiluAndMul) function by name."""
     act_fn_name = act_fn_name.lower()
     if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
         raise ValueError(
             f"Activation function {act_fn_name!r} is not supported.")

-    act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
-    if (quant_config is not None
-            and act_fn_name in quant_config.get_scaled_act_names()):
-        if intermediate_size is None:
-            raise ValueError("intermediate_size must be specified for scaled "
-                             "activation functions.")
-        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
-                                params_dtype)
-    return act_fn
+    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
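For context, the ScaledActivation wrapper that get_act_fn no longer returns divides the activation output by a learned per-channel scale. A simplified sketch of that behavior (paraphrased from the pre-change class, which remains defined in activation.py; tensor-parallel partitioning and weight-loader wiring are omitted, and the class name here is hypothetical):

import torch
import torch.nn as nn

class ScaledActivationSketch(nn.Module):
    """Applies an activation, then rescales by learned per-channel scales."""

    def __init__(self, act: nn.Module, intermediate_size: int):
        super().__init__()
        self.act = act
        self.scales = nn.Parameter(torch.empty(intermediate_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The scales were loaded from the quantized checkpoint; dividing
        # here compensated for scaling folded into the quantized weights.
        return self.act(x) / self.scales

After this commit, AWQ models get the unwrapped activation instead, so no "scales" parameters are created or loaded through this path.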
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/aqlm.py
@@ -213,9 +213,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return AQLMLinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class AQLMLinearMethod(LinearMethodBase):
     """Linear method for AQLM.
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/awq.py
@@ -77,9 +77,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return AWQLinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
-

 def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/awq_marlin.py
@@ -127,9 +127,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return AWQMoEMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         # Extract data from quant config.
8 changes: 0 additions & 8 deletions vllm/model_executor/layers/quantization/base_config.py
@@ -133,11 +133,3 @@ def get_quant_method(self, layer: torch.nn.Module,
         method.
         """
         raise NotImplementedError
-
-    @abstractmethod
-    def get_scaled_act_names(self) -> List[str]:
-        """Returns the activation function names that should be post-scaled.
-
-        For now, this is only used by AWQ.
-        """
-        raise NotImplementedError
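With get_scaled_act_names() gone from the abstract base class, a quantization config implements only the remaining hooks. A minimal sketch of a post-change subclass, assuming the remaining abstract interface is get_name, get_supported_act_dtypes, get_min_capability, get_config_filenames, from_config, and get_quant_method (the class and all its return values are hypothetical):

from typing import Any, Dict, List, Optional

import torch

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class NoopQuantConfig(QuantizationConfig):
    """Hypothetical config; note there is no get_scaled_act_names()."""

    def get_name(self) -> str:
        return "noop"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NoopQuantConfig":
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional[Any]:
        return None  # fall back to the unquantized implementation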
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -114,9 +114,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return BitsAndBytesLinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
     # Split the prefix into its dot-separated components
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -45,9 +45,6 @@ def __init__(self,
     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]

3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/deepspeedfp.py
@@ -50,9 +50,6 @@ def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
     def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
         return DeepSpeedFPLinearMethod(self)

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.half, torch.bfloat16]
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/experts_int8.py
@@ -45,9 +45,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return ExpertsInt8MoEMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class ExpertsInt8MoEMethod(FusedMoEMethodBase):

3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -64,9 +64,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return FBGEMMFp8LinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class FBGEMMFp8LinearMethod(LinearMethodBase):

3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -92,9 +92,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return Fp8KVCacheMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/gguf.py
@@ -48,9 +48,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return GGUFEmbeddingMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                   qweight_type: int) -> torch.Tensor:
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/gptq.py
@@ -80,9 +80,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return GPTQLinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class ExllamaState(Enum):

3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -125,9 +125,6 @@ def get_quant_method(
             return GPTQMarlinMoEMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         # Extract data from quant config.
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/gptq_marlin_24.py
@@ -127,9 +127,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return GPTQMarlin24LinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class GPTQMarlin24LinearMethod(LinearMethodBase):
     """Linear method for Marlin24.
6 changes: 0 additions & 6 deletions vllm/model_executor/layers/quantization/ipex_quant.py
@@ -93,12 +93,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return self.quant_method(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        if self.method == "awq":
-            return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
-        else:
-            return []
-

 class IPEXAWQLinearMethod(AWQLinearMethod):
     """AWQ linear method using IPEX for the CPU backend.
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/marlin.py
@@ -110,9 +110,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return MarlinLinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class MarlinLinearMethod(LinearMethodBase):
     """Linear method for Marlin.
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -68,9 +68,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return ModelOptFp8KVCacheMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
     """
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/neuron_quant.py
@@ -57,9 +57,6 @@ def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
             "Neuron Quantization is only supported through"
             " transformers_neuronx.")

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     def get_quantization_config(self):
         from transformers_neuronx.config import QuantizationConfig
         return QuantizationConfig(quant_dtype=self.quant_dtype,
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/qqq.py
@@ -112,9 +112,6 @@ def get_quant_method(self, layer: torch.nn.Module,
             return QQQLinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class QQQLinearMethod(LinearMethodBase):
     """Linear method for QQQ.
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/quantization/tpu_int8.py
@@ -50,9 +50,6 @@ def get_quant_method(self, layer: Module,
             return TPUInt8LinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class TPUInt8LinearMethod(LinearMethodBase):
     """Int8 Linear method for TPU Quant. """
8 changes: 3 additions & 5 deletions vllm/model_executor/models/bart.py
@@ -393,8 +393,7 @@ def __init__(
             cache_config=cache_config,
             quant_config=quant_config)
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config)
+        self.activation_fn = get_act_fn(config.activation_function)

         ffn_hidden_size = self.embed_dim
         ffn_intermediate_size = config.encoder_ffn_dim
@@ -405,7 +404,7 @@ def __init__(
             bias=ffn_has_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size)
+        self.act = get_act_fn("gelu")
         self.fc2 = RowParallelLinear(
             ffn_intermediate_size,
             ffn_hidden_size,
@@ -473,8 +472,7 @@ def __init__(
             config=config,
             cache_config=cache_config,
             quant_config=quant_config)
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config)
+        self.activation_fn = get_act_fn(config.activation_function)

         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         '''
2 changes: 1 addition & 1 deletion vllm/model_executor/models/bloom.py
@@ -146,7 +146,7 @@ def __init__(
             4 * hidden_size,
             quant_config=quant_config,
         )
-        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
+        self.gelu_impl = get_act_fn("gelu")
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
2 changes: 1 addition & 1 deletion vllm/model_executor/models/falcon.py
@@ -212,7 +212,7 @@ def __init__(
                                                   bias=config.bias,
                                                   skip_bias_add=True,
                                                   quant_config=quant_config)
-        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
+        self.act = get_act_fn("gelu")
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
         self.dense_4h_to_h = RowParallelLinear(
3 changes: 1 addition & 2 deletions vllm/model_executor/models/gpt2.py
@@ -123,8 +123,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.c_proj",
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
3 changes: 1 addition & 2 deletions vllm/model_executor/models/gpt_bigcode.py
@@ -135,8 +135,7 @@ def __init__(
             bias=True,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
3 changes: 1 addition & 2 deletions vllm/model_executor/models/gpt_j.py
@@ -130,8 +130,7 @@ def __init__(
             hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc_in(hidden_states)
3 changes: 1 addition & 2 deletions vllm/model_executor/models/gpt_neox.py
@@ -128,8 +128,7 @@ def __init__(
             config.hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config,
-                              config.intermediate_size)
+        self.act = get_act_fn(config.hidden_act)

     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
2 changes: 1 addition & 1 deletion vllm/model_executor/models/mpt.py
@@ -153,7 +153,7 @@ def __init__(
             bias=not config.no_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn("gelu", quant_config, intermediate_size)
+        self.act = get_act_fn("gelu")
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
3 changes: 1 addition & 2 deletions vllm/model_executor/models/opt.py
@@ -147,8 +147,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
         )
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config, config.ffn_dim)
+        self.activation_fn = get_act_fn(config.activation_function)
         self.fc2 = RowParallelLinear(
             config.ffn_dim,
             self.embed_dim,
2 changes: 1 addition & 1 deletion vllm/model_executor/models/persimmon.py
@@ -60,7 +60,7 @@ def __init__(self,
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                                config.hidden_size,
                                                quant_config=quant_config)
-        self.act = get_act_fn(config.hidden_act, quant_config)
+        self.act = get_act_fn(config.hidden_act)

     def forward(self, hidden_states) -> torch.Tensor:
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
2 changes: 1 addition & 1 deletion vllm/model_executor/models/phi.py
@@ -152,7 +152,7 @@ def __init__(self,
             config.hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
+        self.act = get_act_fn(config.hidden_act)

     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)