[ Misc ] Refactor w8a8 to use process_weights_after_load (Simplify Weight Loading) #5940

Merged

Changes from all commits (28 commits):
fed75ad  remove scales shard indexer fp8 in row parallel linear (Jun 27, 2024)
8a0d664  stash (Jun 27, 2024)
7d8e9cd  removed from column parallel linear too (Jun 27, 2024)
2803761  format (Jun 27, 2024)
b35a097  Merge branch 'remove-fp8-shard-indexer-row' into unify-fp8-int8-scales (Jun 27, 2024)
34884c4  refactored w8a8 loading to use process_weights_after_load, which simp… (Jun 27, 2024)
1429b82  nits (Jun 27, 2024)
d8b639c  Update compressed_tensors_scheme.py (robertgshaw2-neuralmagic, Jun 27, 2024)
6117249  clean up PR (Jun 27, 2024)
ef01048  updated name (Jun 27, 2024)
76d113a  updated comment (Jun 27, 2024)
45aefff  format (Jun 27, 2024)
057dca7  reorder to make pr easier to read (Jun 27, 2024)
d9f5512  format (Jun 27, 2024)
c778ad3  remove comment (Jun 27, 2024)
195650e  format (Jun 27, 2024)
7e28476  fix tests (Jun 28, 2024)
198cf45  format (Jun 28, 2024)
bcb1a07  Merge branch 'main' into unify-fp8-int8-scales (Jun 28, 2024)
3fbacd0  fix test (Jun 28, 2024)
ca1d341  Merge branch 'main' into unify-fp8-int8-scales (robertgshaw2-neuralmagic, Jun 30, 2024)
9418b0b  fix test (robertgshaw2-neuralmagic, Jun 30, 2024)
8994cfe  push up current state (robertgshaw2-neuralmagic, Jun 30, 2024)
32c8eb2  rename per cody's request (Jun 30, 2024)
8e14af6  merge with phi pr, add test, format (Jun 30, 2024)
10f1889  Update compressed_tensors_w8a8.py (robertgshaw2-neuralmagic, Jun 30, 2024)
123c930  Merge branch 'upstream-main' into unify-fp8-int8-scales (robertgshaw2-neuralmagic, Jun 30, 2024)
3857689  Merge branch 'upstream-main' into unify-fp8-int8-scales (robertgshaw2-neuralmagic, Jun 30, 2024)

28 changes: 19 additions & 9 deletions tests/quantization/test_compressed_tensors.py
@@ -11,14 +11,18 @@
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationType)


@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor"),
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel"),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
QuantizationType.INT, 2560),
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
QuantizationType.INT, 2560),
])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy = model_args
model_path, strategy, quant_type, shape_0 = model_args
with vllm_runner(model_path, enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
@@ -34,17 +38,23 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
CompressedTensorsLinearMethod)
assert isinstance(down_proj.quant_method,
CompressedTensorsLinearMethod)

assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)

assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8
assert o_proj.weight.dtype is torch.int8
assert gate_up_proj.weight.dtype is torch.int8
expected_type = (torch.int8 if quant_type == QuantizationType.INT else
torch.float8_e4m3fn)

assert qkv_proj.weight.dtype is expected_type
assert o_proj.weight.dtype is expected_type
assert gate_up_proj.weight.dtype is expected_type

if qkv_proj.scheme.strategy == "tensor":
assert qkv_proj.weight_scale.shard_splitter is not None
assert qkv_proj.weight_scale.logical_widths is not None
# Make sure it is a channelwise buffer
# After running process_weights_after_loading
assert len(qkv_proj.weight_scale.shape) == 2
assert qkv_proj.weight_scale.shape[0] == shape_0
assert qkv_proj.weight_scale.shape[1] == 1
assert qkv_proj.weight_scale.dtype is torch.float32
assert qkv_proj.input_scale.dtype is torch.float32


17 changes: 17 additions & 0 deletions tests/quantization/test_fp8.py
@@ -9,6 +9,23 @@
from vllm._custom_ops import scaled_fp8_quant
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod

MODELS = [
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
"nm-testing/Phi-3-mini-128k-instruct-FP8",
]


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
def test_model_load_and_run(vllm_runner, model: str):
with vllm_runner(model) as llm:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
outputs = llm.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)
print(outputs[0][1])


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
110 changes: 39 additions & 71 deletions vllm/model_executor/layers/linear.py
@@ -41,6 +41,29 @@ def adjust_bitsandbytes_shard(param: Parameter,
return quantized_size, quantized_offset


def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
"""For fused modules (QKV and MLP) we have an array of length
N that holds 1 scale for each "logical" matrix. So the param
is an array of length N. The loaded_weight corresponds to
one of the shards on disk. Here, we slice the param based on
the shard_id for loading.
"""
qkv_idxs = {"q": 0, "k": 1, "v": 2}

if isinstance(shard_id, str):
shard_id = qkv_idxs[shard_id]
elif not isinstance(shard_id, int):
raise ValueError(f"Unknown Shard Id {shard_id}")

# AutoFP8 scales do not have a shape
# compressed-tensors scales do have a shape
if len(loaded_weight.shape) != 0:
assert loaded_weight.shape[0] == 1
loaded_weight = loaded_weight[0]

return param[shard_id], loaded_weight
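
To illustrate how this helper behaves, here is a minimal sketch for a fused QKV scale parameter; the tensors and values are made up for illustration, and it assumes the helper is importable from vllm.model_executor.layers.linear as added in this diff.

```python
import torch

# Assumption: the helper above is importable from the module that defines it.
from vllm.model_executor.layers.linear import adjust_scalar_to_fused_array

# Fused QKV layer: one per-tensor scale slot per logical matrix
# (index 0 -> q, 1 -> k, 2 -> v), so the param is an array of length 3.
qkv_weight_scale = torch.zeros(3, dtype=torch.float32)

# Scale for the "k" shard as it might appear on disk: compressed-tensors
# checkpoints give it shape [1]; AutoFP8 stores a 0-dim scalar.
loaded_scale = torch.tensor([0.5], dtype=torch.float32)

# Slice out the destination slot, squeeze the loaded value, and copy,
# mirroring what weight_loader does with the returned pair.
param_slot, value = adjust_scalar_to_fused_array(qkv_weight_scale,
                                                 loaded_scale, "k")
param_slot.copy_(value)  # qkv_weight_scale is now [0.0, 0.5, 0.0]
```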


class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""

@@ -358,37 +381,15 @@ def weight_loader(self,
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)

param_shard_splitter = getattr(param, "shard_splitter", None)

if output_dim is not None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it should be applied before the weight is
# loaded/copied to the parameter. The shard_splitter applies
# logic by using the loaded_shard_id to ensure that the loaded
# param is loaded to the correct location
# within the parameter defined by the linear method.
if loaded_shard_id is None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)

# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv/mlp).
if output_dim is None:
# If fp8 + scale, need to send to each shard.
if fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
param_data, loaded_weight, loaded_shard_id)
if needs_scalar_to_array is not None:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, 0)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
@@ -450,15 +451,9 @@ def weight_loader(self,
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)

# If a param_shard_splitter is defined by the LinearMethod, use it.
elif param_shard_splitter is not None:
logical_widths = getattr(param, "logical_widths", None)
param_data, loaded_weight = param_shard_splitter(
param_data, loaded_weight, loaded_shard_id, logical_widths)

# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, loaded_shard_id)

else:
@@ -548,36 +543,15 @@ def weight_loader(self,
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)

param_shard_splitter = getattr(param, "shard_splitter", None)

if output_dim is not None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it should be applied before the weight is
# loaded/copied to the parameter. The shard_splitter applies
# logic by using the loaded_shard_id to ensure that the loaded
# param is loaded to the correct location
# within the parameter defined by the linear method.
if loaded_shard_id is None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)

# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
# Special case for per-tensor scales in fused case.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv/mlp).
if output_dim is None:
# If fp8 + scale, need to send to each shard.
if fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
param_data, loaded_weight, loaded_shard_id)
if needs_scalar_to_array is not None:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, 0)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
@@ -667,15 +641,9 @@ def weight_loader(self,
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif param_shard_splitter is not None:
logical_widths = getattr(param, "logical_widths", None)
param_data, loaded_weight = param_shard_splitter(
param_data, loaded_weight, loaded_shard_id, logical_widths)

# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, loaded_shard_id)
else:
ignore_warning = getattr(param, "ignore_warning", False)
@@ -186,6 +186,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
return layer.scheme.process_weights_after_loading(layer)
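
For orientation, a rough sketch (not part of this diff; the function name is hypothetical) of how this hook is exercised: after every checkpoint shard has been copied into the model, the loader walks the modules and gives each quantized layer's method one chance to repack its weights, and this class simply forwards that call to the per-layer scheme.

```python
import torch

def finalize_quantized_weights(model: torch.nn.Module) -> None:
    """Hypothetical post-load pass, loosely mirroring what the model
    loader does once all checkpoint weights have been copied in."""
    for module in model.modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is None:
            continue
        hook = getattr(quant_method, "process_weights_after_loading", None)
        if hook is not None:
            # For CompressedTensorsLinearMethod this ends up calling
            # layer.scheme.process_weights_after_loading(layer).
            hook(module)
```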

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
@@ -31,3 +31,11 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):

"""
raise NotImplementedError

@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
"""
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise NotImplementedError
@@ -18,6 +18,9 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
in a linear transformation.
"""

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
@@ -29,6 +29,9 @@ def __init__(self,
raise ValueError(
"group_size must be given when using strategy group")

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass

def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
@@ -15,70 +15,63 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
def __init__(self, strategy: str):
self.strategy = strategy

def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id

assert isinstance(shard_id, str)
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]

def scales_shard_splitter(
self, param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int],
logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shard_id = self._shard_id_as_int(shard_id)
offset = sum(logical_widths[:shard_id])
size = logical_widths[shard_id]
# update loaded weight with copies for broadcast.
loaded_weight = loaded_weight.repeat(size)
return param[offset:offset + size], loaded_weight
# Cutlass kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), we convert to the per-channel case.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if (self.strategy == QuantizationStrategy.TENSOR
and len(self.logical_widths) > 1):

# Load the N per-tensor scales into the channelwise buffer.
weight_scale_channel = torch.empty(
(sum(self.logical_widths), 1),
dtype=torch.float32,
device=layer.weight_scale.device)
start = 0
for idx, logical_width in enumerate(self.logical_widths):
end = start + logical_width
weight_scale_channel[start:end, :] = layer.weight_scale[idx]
start = end

layer.weight_scale = Parameter(weight_scale_channel,
requires_grad=False)

Review comment (Contributor): nit: could just move this to the per tensor scheme
Author reply (Collaborator): strategy refers to the weight quantization here
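
To make the per-tensor to per-channel expansion above concrete, here is a tiny standalone sketch of the same loop with made-up logical widths [4, 2, 2] and scale values; the numbers are illustrative only.

```python
import torch

# Three per-tensor weight scales for a fused module whose logical output
# widths are [4, 2, 2] (toy sizes), i.e. one scale per logical matrix.
logical_widths = [4, 2, 2]
weight_scale = torch.tensor([0.5, 0.25, 0.125], dtype=torch.float32)

# Expand into the (sum(logical_widths), 1) channelwise buffer by repeating
# each per-tensor scale over the rows of its logical matrix.
weight_scale_channel = torch.empty((sum(logical_widths), 1),
                                   dtype=torch.float32)
start = 0
for idx, logical_width in enumerate(logical_widths):
    end = start + logical_width
    weight_scale_channel[start:end, :] = weight_scale[idx]
    start = end

print(weight_scale_channel.squeeze(1).tolist())
# [0.5, 0.5, 0.5, 0.5, 0.25, 0.25, 0.125, 0.125]
```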

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.logical_widths = output_partition_sizes

is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(output_partition_sizes) if (
is_tensor_partitioned
or self.strategy == QuantizationStrategy.CHANNEL) else 1

shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
# WEIGHT SCALE
shape: Union[Tuple[int], Tuple[int, int]]
if self.strategy == QuantizationStrategy.CHANNEL:
shape = (weight_scale_dim, 1)
shape = (sum(self.logical_widths), 1)
else:
shape = (len(self.logical_widths), )

weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
requires_grad=False)

layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
if self.strategy == QuantizationStrategy.CHANNEL:
set_weight_attrs(weight_scale, {
"weight_loader": weight_loader,
"output_dim": 0,
})
else:
set_weight_attrs(weight_scale, {
"weight_loader": weight_loader,
"needs_scalar_to_array": True,
})

# WEIGHT
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)

layer.register_parameter("weight", weight)
set_weight_attrs(
weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
"logical_widths": output_partition_sizes
})

# Don't need a shard_splitter for channel-wise quantization
# Use the default loading method
if self.strategy == QuantizationStrategy.CHANNEL:
set_weight_attrs(weight_scale, {
"output_dim": 0,
})
else:
set_weight_attrs(
weight_scale, {
"logical_widths": output_partition_sizes,
"shard_splitter": self.scales_shard_splitter,
})
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})
@@ -29,6 +29,9 @@ def __init__(self,
raise ValueError(
"group_size must be given when using strategy group")

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass

def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,