diff --git a/CMakeLists.txt b/CMakeLists.txt index 35846fd1cfa99..b668cbc97de15 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,6 +167,7 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" diff --git a/csrc/ops.h b/csrc/ops.h index f5e0e423bb65d..b839eaf0d26c8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -93,6 +93,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, #endif +void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input, + float scale); + void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index cba07f0ae9f2a..cdbec4a34d77f 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -67,6 +67,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Aligning the number of tokens to be processed by each expert such " "that it is divisible by the block size."); + ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, + "Compute int8 quantized tensor for given scaling factor"); + // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); cache_ops.def("swap_blocks", &swap_blocks, diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu new file mode 100644 index 0000000000000..4902e4c23434c --- /dev/null +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -0,0 +1,59 @@ +#include +#include +#include + +#include "../../dispatch_utils.h" + +static inline __device__ int8_t float_to_int8_rn(float x) { +#ifdef USE_ROCM + static const float i8_min = + static_cast(std::numeric_limits::min()); + static const float i8_max = + static_cast(std::numeric_limits::max()); + // round + float dst = std::nearbyint(x); + // saturate + dst = std::clamp(dst, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +namespace vllm { + +template +__global__ void static_scaled_int8_quant_kernel( + const scalar_t* __restrict__ input, int8_t* __restrict__ out, + scale_type scale, const int hidden_size) { + const int tid = threadIdx.x; + const int token_idx = blockIdx.x; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = + float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); + } +} +} // namespace vllm + +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + float scale) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { + vllm::static_scaled_int8_quant_kernel + <<>>(input.data_ptr(), + out.data_ptr(), scale, + hidden_size); + }); +} diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py new file mode 100644 index 0000000000000..b9aa00ce13f56 --- /dev/null +++ b/tests/kernels/test_int8_quant.py @@ -0,0 +1,31 @@ +import pytest +import torch + +from vllm._C import ops + +DTYPES = [torch.half, torch.bfloat16, torch.float] +HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192] # Arbitrary values for testing +NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +SEEDS = [0] +SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, + seed: int, scale: float) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + + out1 = (x / scale).round().clamp( + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + ops.static_scaled_int8_quant(out2, x, scale) + assert torch.allclose(out1, out2, + atol=1) # big atol to account for rounding errors diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py new file mode 100644 index 0000000000000..b83286992da3d --- /dev/null +++ b/tests/quantization/test_compressed_tensors.py @@ -0,0 +1,36 @@ +"""Test model set-up and weight loading for sparseml-quantized models. + +Run `pytest tests/quantization/test_compressed_tensors.py`. +""" + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor) + + +def test_compressed_tensors_w8a8_static_setup(vllm_runner): + model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed" + llm = vllm_runner(model_path, quantization="sparseml", enforce_eager=True) + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod) + + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor) + + assert qkv_proj.weight.dtype is torch.int8 + assert o_proj.weight.dtype is torch.int8 + assert gate_up_proj.weight.dtype is torch.int8 + + assert qkv_proj.weight_scale.shard_splitter is not None + assert qkv_proj.weight_scale.logical_widths is not None + assert qkv_proj.input_scale.dtype is torch.float32 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9e7d0d96bf004..f0fab4d8aa26d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -251,6 +251,24 @@ def scaled_fp8_quant( return output, scale +# int8 +def static_scaled_int8_quant(input: torch.Tensor, + scale: float) -> torch.Tensor: + """ + Quantize the input tensor to int8 and return the quantized tensor. + + Args: + input: The input tensor to be quantized to int8. + scale: Scaling factor for the int8 quantization. + + Returns: + torch.Tensor: Output tensor in int8. + """ + q = torch.empty_like(input, dtype=torch.int8) + vllm_ops.static_scaled_int8_quant(q, input, scale) + return q + + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 4fcc7eee09cde..0a26cadf90bb4 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -58,7 +58,6 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """Apply the weights in layer to the input tensor. - Expects create_weights to have been called before on the layer.""" raise NotImplementedError @@ -79,8 +78,7 @@ def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - output_size_per_partition = sum(output_partition_sizes) - weight = Parameter(torch.empty(output_size_per_partition, + weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, dtype=params_dtype), requires_grad=False) @@ -151,15 +149,13 @@ class ReplicatedLinear(LinearBase): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -212,17 +208,15 @@ class ColumnParallelLinear(LinearBase): the list would be size 3. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[List[int]] = None, - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -230,18 +224,26 @@ def __init__( # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, tp_size) + assert self.quant_method is not None + self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, tp_size) + for output_size in self.output_sizes + ] + if output_sizes is None: output_sizes = [output_size] - # All the linear layer supports quant method. - assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size, - [x // tp_size for x in output_sizes], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -323,24 +325,26 @@ class MergedColumnParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_sizes: List[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): self.output_sizes = output_sizes # UPSTREAM SYNC: needed for LazyCompressedParameter self.loaded_shards: Set[int] = set() tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) - super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, quant_config, - self.output_sizes) + super().__init__(input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config) def weight_loader(self, param: Parameter, @@ -351,6 +355,26 @@ def weight_loader(self, output_dim = getattr(param, "output_dim", None) # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) + + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # Special case for Fp8 scales. fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", None) @@ -411,6 +435,13 @@ def weight_loader(self, shard_size = loaded_weight.shape[0] shard_offset = loaded_shard_id * shard_size param_data = param_data.narrow(0, shard_offset, shard_size) + + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths", None) + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + # Special case for Fp8 scales. elif fp8_scales_shard_indexer is not None: param_data, loaded_weight = fp8_scales_shard_indexer( @@ -424,6 +455,14 @@ def weight_loader(self, "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") + if fp8_scales_shard_indexer is None: + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + # UPSTREAM SYNC: needed for LazyCompressedParameter self.loaded_shards.add(loaded_shard_id) assert param_data.shape == loaded_weight.shape @@ -463,17 +502,15 @@ class QKVParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. """ - def __init__( - self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -495,14 +532,19 @@ def __init__( input_size = self.hidden_size output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size - output_sizes = [ - self.num_heads * tp_size * self.head_size, - self.num_kv_heads * tp_size * self.head_size, - self.num_kv_heads * tp_size * self.head_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj ] - super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, quant_config, output_sizes) + super().__init__(input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config) def weight_loader(self, param: Parameter, @@ -512,6 +554,26 @@ def weight_loader(self, output_dim = getattr(param, "output_dim", None) # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) + + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # Special case for Fp8 scales. fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", None) @@ -550,6 +612,8 @@ def weight_loader(self, tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. if output_dim is not None: if loaded_shard_id == "q": shard_offset = 0 @@ -589,6 +653,12 @@ def weight_loader(self, shard_index = ["q", "k", "v"].index(loaded_shard_id) param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths", None) + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + # Special case for Fp8 scales. elif fp8_scales_shard_indexer is not None: param_data, loaded_weight = fp8_scales_shard_indexer( @@ -600,6 +670,13 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "QKVParallelLinear, assume the weight is the same " "for all partitions.") + + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -639,17 +716,15 @@ class RowParallelLinear(LinearBase): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -659,16 +734,15 @@ def __init__( # Divide the weight matrix along the last dimension. self.tp_size = get_tensor_model_parallel_world_size() self.input_size_per_partition = divide(input_size, self.tp_size) - # All the linear layer supports quant method. assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size_per_partition, - [self.output_size], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) - + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=[self.output_size], + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") @@ -696,12 +770,16 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + # Special case for Fp8 scales. elif fp8_scales_shard_indexer is not None: param_data, loaded_weight = fp8_scales_shard_indexer(param_data, loaded_weight, shard_id=0) + if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index f938e7d37ec5f..7b9abe1b629a1 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,6 +4,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig) from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig) from vllm.model_executor.layers.quantization.fp8 import Fp8Config @@ -27,6 +29,7 @@ "gptq_marlin": GPTQMarlinConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, + "sparseml": CompressedTensorsConfig, } diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py new file mode 100644 index 0000000000000..19e464bd64325 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -0,0 +1,151 @@ +from typing import Any, Dict, List, Optional + +import torch + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme, CompressedTensorsW8A8StaticTensor) + + +class CompressedTensorsConfig(QuantizationConfig): + + def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]): + self.ignore = ignore + self.layer_quant_details = layer_quant_details + + def get_linear_method(self) -> "CompressedTensorsLinearMethod": + return CompressedTensorsLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16] + + # Need to figure it out + def get_min_capability(self) -> int: + return 60 + + def get_name(self) -> str: + return "compressed_tensors" + + def get_quant_method( + self, layer: torch.nn.Module + ) -> Optional["CompressedTensorsLinearMethod"]: + if isinstance(layer, LinearBase): + return CompressedTensorsLinearMethod(self) + return None + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + layer_quant_details: Dict[str, Any] = dict() + ignore: List[str] = config.get("ignore", None) + + for key, quant_config in config["config_groups"].items(): + targets = quant_config.get("targets") + for target in targets: + layer_quant_details[target] = {} + layer_quant_details[target]["weight"] = quant_config.get( + "weights") + layer_quant_details[target]["input"] = quant_config.get( + "input_activations") + + return cls(layer_quant_details=layer_quant_details, ignore=ignore) + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + def _get_schema(self, weight_quant: Dict, input_quant: Dict): + # TODO: Refactor as additional cases are supported + + weight_bit = weight_quant.get("num_bits") + input_bit = input_quant.get("num_bits") + + weight_strategy = weight_quant.get("strategy") + input_strategy = input_quant.get("strategy") + + weight_symmetric = weight_quant.get("symmetric") + input_symmetric = input_quant.get("symmetric") + + is_8_bits = weight_bit == input_bit == 8 + is_tensor = weight_strategy == input_strategy == "tensor" + is_symmetric = weight_symmetric and input_symmetric + + if is_8_bits and is_tensor and is_symmetric and \ + torch.cuda.is_available(): + # CompressedTensorsW8A8StaticTensor only supports CUDA path for + # now. + return CompressedTensorsW8A8StaticTensor() + raise NotImplementedError( + "Scheme not supported. Only CUDA, 8-bit static symmtetric " + "per tensor quantization is currently supported") + + def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": + + # TODO: update with matching function from `compressed_tensors` + layer_type_name = None + layer_name_class = type(layer).__name__.lower() + for target in self.layer_quant_details: + if target.lower() in layer_name_class: + layer_type_name = target + break + if layer_type_name is None: + raise ValueError(f"Could not matching target for layer {layer}") + + layer_quant_details: Dict[str, Any] = self.layer_quant_details.get( + layer_type_name, None) + if layer_quant_details is None: + raise ValueError( + f"Could not find quantization details for {layer}.") + + return self._get_schema(weight_quant=layer_quant_details["weight"], + input_quant=layer_quant_details["input"]) + + +class CompressedTensorsLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: CompressedTensorsConfig): + self.quantization_config = quantization_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. + """ + weight_loader = extra_weight_attrs.get("weight_loader") + + scheme = self.quantization_config.get_scheme(layer=layer) + scheme.create_weights( + layer=layer, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader) + + layer.scheme = scheme + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. + """ + + if bias is not None: + raise ValueError("bias is not supported for this linear method") + + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py new file mode 100644 index 0000000000000..831905b63e2c9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -0,0 +1,5 @@ +from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401 +from .compressed_tensors_unquantized import ( # noqa: F401 + CompressedTensorsUnquantized) +from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501 + CompressedTensorsW8A8StaticTensor) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py new file mode 100644 index 0000000000000..3a5904208656e --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -0,0 +1,33 @@ +from abc import ABC, abstractmethod + +import torch + +__all__ = ["CompressedTensorsScheme"] + + +class CompressedTensorsScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by CompressedTensors. + """ + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: toch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py new file mode 100644 index 0000000000000..0cfac13d1ca25 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py @@ -0,0 +1,39 @@ +from typing import Callable, List + +import torch +import torch.nn.functional as F +from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs + +__all__ = ["CompressedTensorsUnquantized"] + + +class CompressedTensorsUnquantized(CompressedTensorsScheme): + """ + Implements the scheme for all layers which are ignored + in the CompressedTensors config. The input and loaded weight are used + in a linear transformation. + """ + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + device="cuda", + dtype=params_dtype), + requires_grad=False) + + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"weight_loader": weight_loader}) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + return F.linear(x, weight) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py new file mode 100644 index 0000000000000..d16e570d12202 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -0,0 +1,119 @@ +from typing import Callable, List, Tuple, Union + +import torch +from torch.nn import Parameter + +from vllm import _custom_ops as custom_ops +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs + +__all__ = ["CompressedTensorsW8A8StaticTensor"] + + +class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): + + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + assert isinstance(shard_id, str) + qkv_idxs = {"q": 0, "k": 1, "v": 2} + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + + def scales_shard_splitter( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int], + logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shard_id = self._shard_id_as_int(shard_id) + offset = sum(logical_widths[:shard_id]) + size = logical_widths[shard_id] + # update loaded weight with copies for broadcast. + loaded_weight = loaded_weight.repeat(size) + return param[offset:offset + size], loaded_weight + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + # TODO: remove zero_point parameters once the configs given remove them + + # Note on input/weight scales and zero_points + # + # When the scales have a single value, it is required that they be + # on the CPU for 2 reasons, + # 1. Performance: + # When the scales (input_scale/weight_scales) have only a single + # value, we perform a scalar broadcast of that value during the + # quant/dequant operations. The "quant" and the "gemm+dequant" + # kernels accept the Scalar by-value. These tensors are allocated + # on the CPU in order to avoid the GPU-to-CPU copy when passing + # by-value. + # + # 2. CUDA Graphs: + # CUDA Graphs don't support GPU-to-CPU copy operations during + # stream capture. + # + # TODO: zero-points are not supported yet. But we expect a similar + # pattern. + + is_tensor_partitioned = len(output_partition_sizes) != 1 + weight_scale_dim = sum( + output_partition_sizes) if is_tensor_partitioned else 1 + weight_scale_device = "cpu" if weight_scale_dim == 1 else "cuda" + + input_scale = Parameter(torch.empty(1, + device="cpu", + dtype=torch.float32), + requires_grad=False) + input_zero_point = Parameter(torch.empty(1, + device="cpu", + dtype=torch.int8), + requires_grad=False) + + weight_scale = Parameter(torch.empty(weight_scale_dim, + device=weight_scale_device, + dtype=torch.float32), + requires_grad=False) + weight_zero_point = Parameter(torch.empty(1, + device="cpu", + dtype=torch.int8), + requires_grad=False) + + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + requires_grad=False) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + + set_weight_attrs(weight, {"weight_loader": weight_loader}) + + layer.register_parameter("input_scale", input_scale) + set_weight_attrs(input_scale, {"weight_loader": weight_loader}) + layer.register_parameter("input_zero_point", input_zero_point) + set_weight_attrs(input_zero_point, {"weight_loader": weight_loader}) + layer.register_parameter("weight_scale", weight_scale) + set_weight_attrs(weight_scale, {"weight_loader": weight_loader}) + set_weight_attrs( + weight_scale, { + "shard_splitter": self.scales_shard_splitter, + "logical_widths": output_partition_sizes + }) + layer.register_parameter("weight_zero_point", weight_zero_point) + set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + weight_scale = layer.weight_scale + act_scale = layer.input_scale + + # Input quantize + x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item()) + + return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, + weight_scale, x.dtype) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 4e826256bdba7..ecad5041099d8 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -135,6 +135,13 @@ def get_quant_config(model_config: ModelConfig, # Read the quantization config from the HF model config, if available. hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) + if hf_quant_config is None: + compression_config = getattr(model_config.hf_config, + "compression_config", None) + if compression_config is not None: + hf_quant_config = compression_config.get("quantization_config", + None) + if hf_quant_config is not None: return quant_cls.from_config(hf_quant_config) model_name_or_path = model_config.model diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f43a40a0bfd34..086f9294c4f1c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -62,11 +62,12 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, bias=bias, quant_config=quant_config) if hidden_act != "silu": @@ -120,16 +121,16 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=bias, quant_config=quant_config, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=bias, quant_config=quant_config, ) @@ -263,8 +264,10 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)