Skip to content

Add Float8 support for AQT tensor parallel #1003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import torch
from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
from torch.testing._internal.common_utils import run_tests
from torchao.quantization import int8_weight_only
from torchao.quantization import int8_weight_only, float8_weight_only

class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
pass
class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
QUANT_METHOD_FN = staticmethod(int8_weight_only)
copy_tests(TorchAOTensorParallelTestCase, TestInt8woAffineQuantizedTensorParallel, "int8wo_tp")


copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "aqt_tp")
# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
QUANT_METHOD_FN = staticmethod(float8_weight_only)
copy_tests(TorchAOTensorParallelTestCase, TestFloat8woAffineQuantizedTensorParallel, "fp8wo_tp")

if __name__ == "__main__":
run_tests()
46 changes: 40 additions & 6 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,20 +1094,31 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)
if func is aten.clone.default:
elif func is aten.clone.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)
if func is aten.t.default:
elif func is aten.t.default:
"""we don't need to repack the weight and just rely on external
shape being changed and record the status of transpose/no-transpose
"""
args[0].transposed = not args[0].transposed
return return_and_correct_aliasing(func, args, kwargs, args[0])

raise NotImplementedError(
f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported"
)
elif func is aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
)
elif dim == 1:
assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
return Float8AQTLayout(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type)
else:
raise NotImplementedError(f"Float8AQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
else:
raise NotImplementedError(
f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported"
)

__torch_function__ = torch._C._disabled_torch_function_impl

Expand Down Expand Up @@ -1644,6 +1655,28 @@ def _linear_fp8_act_fp8_weight_impl(
use_fast_accum=scaled_mm_config.use_fast_accum,
).reshape(out_shape)

def _linear_fp_act_fp8_weight_check(
input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
bias: Optional[torch.Tensor],
) -> bool:
return (
# input is native float tensor
not is_traceable_wrapper_subclass(input_tensor) and
input_tensor.is_floating_point() and
# weight is float8 quantized affine quantized tensor
isinstance(weight_tensor, AffineQuantizedTensor) and
isinstance(weight_tensor.layout_type, Float8LayoutType)
and weight_tensor.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
and (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor))
)

def _linear_fp_act_fp8_weight_impl(
input_tensor: torch.Tensor,
weight_tensor: AffineQuantizedTensor,
bias: Optional[torch.Tensor],
):
return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)

def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
return (
Expand Down Expand Up @@ -1694,6 +1727,7 @@ def _register_aqt_quantized_linear_dispatches():
(_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
(_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl),
(_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl),
(_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
(_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
(_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
(_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl),
Expand Down
Loading