
Commit 72d2518

jerryzh168 and kwen2501 authored
Supporting tensor parallelism for int8 weight only quant (#939)
* [WIP] Supporting tensor parallelism for int8 weight only quant

  Summary: Following https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/tensor_parallel.py, we can support tensor parallelism for int8 weight-only quant; this is needed for torchchat.

  Test Plan: python test/dtypes/test_affine_quantized_tensor_parallel.py

* implement tp for aqt
* fixes
* import fix
* remove cpu test
* fix
* fix
* fix test
* device
* change transpose impl
* Skip compiled TP test for torch version < 2.5
* version util
* fix
* fix version

---------

Co-authored-by: Ke Wen <kw2501@meta.com>
1 parent 63cb7a9 commit 72d2518

File tree

8 files changed, +210 -32 lines changed
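Not part of the diff, but useful context for review: the sketch below shows the end-to-end workflow this commit enables, quantizing a linear layer with int8_weight_only and then sharding its quantized weight as a DTensor. It is a minimal sketch assuming a torchrun launch with one CUDA GPU per rank and a recent PyTorch (2.5+ if torch.compile is also wanted); it mirrors the test case added in torchao/testing/utils.py further down.

# Hedged sketch (not from this commit): int8 weight-only quant + tensor parallelism.
# Assumes `torchrun --nproc_per_node=<ngpus> script.py`, one CUDA device per rank.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard, Replicate
from torch.distributed.device_mesh import init_device_mesh
from torchao.quantization import quantize_, int8_weight_only

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

torch.manual_seed(5)                          # identical weights on every rank
lin = torch.nn.Linear(1024, 2048, bias=False, device="cuda")
quantize_(lin, int8_weight_only())            # weight -> AffineQuantizedTensor

# Column-wise TP: each rank keeps a row-slice of the quantized weight
# (this exercises the new aten.slice.Tensor support in this commit).
rows = lin.out_features // mesh.size()
local = lin.weight[rank * rows:(rank + 1) * rows, :]
lin.weight = torch.nn.Parameter(
    DTensor.from_local(local, mesh, [Shard(0)]), requires_grad=False
)

x = torch.randn(16, 1024, device="cuda")
x_d = DTensor.from_local(x, mesh, [Replicate()])
y = lin(x_d)                                  # int8 weight-only matmul, sharded output
print(y.shape)                                # global shape: torch.Size([16, 2048])
dist.destroy_process_group()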

test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 0 deletions

@@ -135,5 +135,6 @@ def test_print_quantized_module(self, apply_quant):
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 
+
 if __name__ == "__main__":
     run_tests()
test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
+from torch.testing._internal.common_utils import run_tests
+from torchao.quantization import int8_weight_only
+
+class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+    pass
+
+
+copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "aqt_tp")
+
+if __name__ == "__main__":
+    run_tests()
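The new test file stays intentionally thin: all of the logic lives in the TorchAOTensorParallelTestCase added to torchao/testing/utils.py below, and copy_tests stamps out the parametrized tests. As a sketch of how the harness is meant to be reused (this PR only wires up and claims support for int8_weight_only; any other method is an assumption, not something this commit tests), a subclass can override the quantization hooks:

# Sketch only: reusing the TP test harness for a different quant method.
# int4_weight_only is a real torchao API, but running it under this harness is
# an assumption; this commit only exercises int8_weight_only.
from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
from torch.testing._internal.common_utils import run_tests
from torchao.quantization import int4_weight_only


class TestInt4WeightOnlyTensorParallel(TorchAOTensorParallelTestCase):
    QUANT_METHOD_FN = staticmethod(int4_weight_only)
    QUANT_METHOD_KWARGS = {"group_size": 32}


copy_tests(TorchAOTensorParallelTestCase, TestInt4WeightOnlyTensorParallel, "int4wo_tp")

if __name__ == "__main__":
    run_tests()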

torchao/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -33,11 +33,13 @@
     quantize_,
 )
 from . import dtypes
+from . import testing
 
 __all__ = [
     "dtypes",
     "autoquant",
     "quantize_",
+    "testing",
 ]
 
 # test-pytorchbot

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 49 additions & 3 deletions

@@ -38,7 +38,8 @@
     find_multiple,
     TorchAOBaseTensor,
     TORCH_VERSION_AT_LEAST_2_5,
-    _is_float8_type
+    _is_float8_type,
+    fill_defaults,
 )
 import logging
 
@@ -603,13 +604,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
             )
 
-        if func is aten.t.default:
+        elif func is aten.t.default:
             tensor = args[0]
             new = tensor.__class__(
-                tensor.int_data.view(tensor.shape[::-1]), tensor.scale, tensor.zero_point, tensor.layout_type
+                tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor.layout_type
             )
             return return_and_correct_aliasing(func, args, kwargs, new)
 
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                )
+            elif dim == 1:
+                assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
+                return PlainAQTLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self.layout_type)
+            else:
+                raise NotImplementedError(f"PlainAQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
+
         raise NotImplementedError(
             f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported"
        )

@@ -1776,6 +1789,39 @@ def _(func, types, args, kwargs):
     )
     return return_and_correct_aliasing(func, args, kwargs, new)
 
+@implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    assert step == 1
+    assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}"
+    if end >= self.shape[dim]:
+        end = self.shape[dim]
+    shape = list(self.shape)
+    shape[dim] = end - start
+    block_size = self.block_size
+    assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}"
+    # with slice, some shape dimension might be smaller than block_size dimension, so
+    # we need to make sure there is no overflow
+    block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1]))
+    new = self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+# this is needed for DTensor.from_local() and for flattening tensor
+@implements(aten.view.default)
+def _(func, types, args, kwargs):
+    self, shape = args
+
+    if tuple(self.shape) == tuple(shape):
+        return self.__class__(self.layout_tensor, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())
+
+    if len(shape) == 1 and shape[0] == -1:
+        assert len(self.block_size) == 2 and self.block_size[0] == 1
+        block_size = (self.block_size[1],)
+        return self.__class__(self.layout_tensor, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())
+
+    raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]")
+
+
 to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
 to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
 to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
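The two new dispatch entries are exactly what DTensor needs from a quantized weight: sharding row-slices the tensor (aten.slice.Tensor), and DTensor.from_local() needs a same-shape view of the local shard (aten.view.default, per the comment in the diff). Below is a minimal single-process sketch of the slicing path, assuming a CUDA build of torchao with this patch applied; the DTensor wrapping itself is shown in the test case further down.

# Sketch: row-slice an int8 weight-only quantized Linear weight, the same
# operation colwise sharding performs per rank (assumes CUDA + this patch).
import torch
from torchao.quantization import quantize_, int8_weight_only

lin = torch.nn.Linear(1024, 2048, bias=False, device="cuda")
quantize_(lin, int8_weight_only())

w = lin.weight                          # AffineQuantizedTensor, shape (2048, 1024)
num_ranks = 4
rows = w.shape[0] // num_ranks

# dim=0 slice: int_data, scale and zero_point are sliced consistently,
# and block_size is clamped to the sliced shape
shard = w[0 * rows:(0 + 1) * rows, :]   # dispatches to the new aten.slice.Tensor
print(shard.shape)                      # torch.Size([512, 1024])

# a same-shape view, which the diff notes is needed for DTensor.from_local()
same = shard.view(shard.shape)
print(type(same).__name__)              # AffineQuantizedTensor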

torchao/testing/utils.py

Lines changed: 115 additions & 0 deletions

@@ -3,11 +3,14 @@
 import copy
 import torch
 import torchao
+import os
 
 from torch.testing._internal import common_utils
 from torchao.dtypes import AffineQuantizedTensor
 from torchao.dtypes import to_affine_quantized_intx
 from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization import quantize_, int8_weight_only
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 """
 How to use:
@@ -213,10 +216,122 @@ def test_linear_compile(self, device, dtype):
         lp_res = torch.compile(l)(hp_act_tensor)
         self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
 
+import torch.distributed as dist
+from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+    NUM_DEVICES,
+)
+
+class TorchAOTensorParallelTestCase(DTensorTestBase):
+    """Basic test case for tensor subclasses
+    """
+    COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
+
+    TENSOR_SUBCLASS = AffineQuantizedTensor
+    QUANT_METHOD_FN = staticmethod(int8_weight_only)
+    QUANT_METHOD_KWARGS = {}
 
+    @staticmethod
+    def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+        """
+        Shard linear layer of the model in column-wise fashion
+        """
+        # Column-wise is wrt to A^T, so for A it is row-wise.
+        # Number of rows per rank
+        orig_weight = m.linear.weight
+        n_local_rows = orig_weight.size(0) // mesh.size()
+        rank = mesh.get_local_rank()
+        local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
+        # Construct DTensor from local shard
+        dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+        # Replace parameter in module
+        m.linear.weight = torch.nn.Parameter(
+            dtensor, requires_grad=False
+        )
+        return m
+
+    @staticmethod
+    def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
+        """
+        Shard linear layer of the model in row-wise fashion
+        """
+        # Row-wise is wrt to A^T, so for A it is column-wise.
+        # Number of columns per rank
+        orig_weight = m.linear.weight
+        n_local_cols = orig_weight.size(1) // mesh.size()
+        rank = mesh.get_local_rank()
+        local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
+        # Construct DTensor from local shard
+        dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
+        # Replace parameter in module
+        m.linear.weight = torch.nn.Parameter(
+            dtensor, requires_grad=False
+        )
+        return m
+
+    def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
+        """
+        Quantize the model
+        """
+        quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
+        return m
+
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @with_comms
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_tp(self, dtype):
+        device = "cuda"
+        # To make sure different ranks create the same module
+        torch.manual_seed(5)
+
+        class M(torch.nn.Module):
+            def __init__(self, in_features, out_features, **kwargs) -> None:
+                super().__init__(**kwargs)
+                self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return self.linear(x)
+
+        # Get rank and device
+        device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
+
+        # Original model
+        proj_up = M(1024, 2048).to(device).to(dtype)
+        proj_dn = M(2048, 1024).to(device).to(dtype)
+        example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
+        y = proj_dn(proj_up(example_input))
+
+        # Quantize the model
+        up_quant = self.quantize(proj_up)
+        dn_quant = self.quantize(proj_dn)
+        y_q = dn_quant(up_quant(example_input))
+
+        mesh = self.build_device_mesh()
+        # Shard the models
+        up_dist = self.colwise_shard(up_quant, mesh)
+        dn_dist = self.rowwise_shard(dn_quant, mesh)
+
+        # We need to turn inputs into DTensor form as well -- just a format change
+        input_dtensor = DTensor.from_local(
+            example_input, mesh, [Replicate()]
+        )
+
+        y_d = dn_dist(up_dist(input_dtensor))
+
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            # Need torch 2.5 to support compiled tensor parallelism
+            return
+
+        up_compiled = torch.compile(up_dist)
+        y_up = up_compiled(input_dtensor)
+        dn_compiled = torch.compile(dn_dist)
+        y_dn = dn_compiled(y_up)
 
 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
 common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
+common_utils.instantiate_parametrized_tests(TorchAOTensorParallelTestCase)
 
 if __name__ == "__main__":
     unittest.main()
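A note on the sharding convention used by colwise_shard and rowwise_shard, since it is easy to trip over: nn.Linear stores W with shape (out_features, in_features) and computes y = x @ W.T, so "column-wise" parallelism (splitting the columns of W.T) is Shard(0) on W, and "row-wise" parallelism is Shard(1) on W. The plain-tensor sketch below (no quantization, no DTensor; sizes are illustrative, not from the test) checks the algebra and shows why the row-wise partial outputs need a reduction:

# Plain-tensor sketch of the colwise/rowwise convention (no quant, no DTensor).
import torch

x = torch.randn(8, 16)         # activations
W_up = torch.randn(32, 16)     # proj_up weight: (out=32, in=16)
W_dn = torch.randn(16, 32)     # proj_dn weight: (out=16, in=32)
world_size = 2

y_ref = (x @ W_up.T) @ W_dn.T  # unsharded reference

partials = []
for rank in range(world_size):
    up_shard = W_up[rank * 16:(rank + 1) * 16, :]   # "colwise" == Shard(0) on W_up
    dn_shard = W_dn[:, rank * 16:(rank + 1) * 16]   # "rowwise" == Shard(1) on W_dn
    h_local = x @ up_shard.T                        # each rank owns a hidden slice
    partials.append(h_local @ dn_shard.T)           # partial output per rank
y_tp = sum(partials)                                # the all-reduce in real TP

torch.testing.assert_close(y_tp, y_ref, rtol=1e-4, atol=1e-4)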

torchao/utils.py

Lines changed: 24 additions & 0 deletions

@@ -493,6 +493,30 @@ def _get_to_kwargs(self, *args, **kwargs):
         }
         return kwargs
 
+def fill_defaults(args, n, defaults_tail):
+    """
+    __torch_dispatch__ doesn't guarantee the number of arguments you are
+    passed (e.g., defaulted arguments are not passed); but usually it is
+    convenient to pad out the arguments list with defaults. This function
+    helps you do that.
+    Args:
+        args: the list of positional arguments passed to __torch_dispatch__
+        n: the number of arguments you are expecting to get
+        defaults_tail: default values for the arguments, starting from the
+            end of the list
+    Example:
+        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
+        [1, 2, 3, 4, 5]
+        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
+        [1, 2, 3, None, None]
+    """
+    if n - len(defaults_tail) > len(args):
+        raise RuntimeError("not enough defaults to fill arguments")
+    r = list(args)
+    for i in range(len(args), n):
+        r.append(defaults_tail[i - n + len(defaults_tail)])
+    return r
+
 
 ## Deprecated, will be deleted in the future
 def _torch_version_at_least(min_version):
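fill_defaults is what lets the new aten.slice.Tensor handlers cope with __torch_dispatch__ omitting defaulted trailing arguments. A standalone illustration of the exact call pattern used above (the placeholder string merely stands in for the tensor argument):

# Standalone illustration of the fill_defaults call used by the slice handlers.
from torchao.utils import fill_defaults

# dispatch may omit defaulted trailing args, e.g. only (tensor, dim) arrives...
args = ("tensor_placeholder", 0)
tensor, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
print(dim, start, end, step)   # 0 None None 1

# ...while an explicit x[a:b:c] slice supplies all five positions
args = ("tensor_placeholder", 1, 0, 512, 1)
tensor, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
print(dim, start, end, step)   # 1 0 512 1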

tutorials/developer_api_guide/my_dtype_tensor_subclass.py

Lines changed: 5 additions & 28 deletions

@@ -25,36 +25,13 @@
     LayoutType,
     PlainLayoutType,
 )
-from torchao.utils import TorchAOBaseTensor
+from torchao.utils import (
+    TorchAOBaseTensor,
+    fill_defaults,
+)
 
 aten = torch.ops.aten
 
-# TODO: move to torchao/utils.py
-def fill_defaults(args, n, defaults_tail):
-    """
-    __torch_dispatch__ doesn't guarantee the number of arguments you are
-    passed (e.g., defaulted arguments are not passed); but usually it is
-    convenient to pad out the arguments list with defaults. This function
-    helps you do that.
-    Args:
-        args: the list of positional arguments passed to __torch_dispatch__
-        n: the number of arguments you are expecting to get
-        defaults_tail: default values for the arguments, starting from the
-            end of the list
-    Example:
-        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
-        [1, 2, 3, 4, 5]
-        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
-        [1, 2, 3, None, None]
-    """
-    if n - len(defaults_tail) > len(args):
-        raise RuntimeError("not enough defaults to fill arguments")
-    r = list(args)
-    for i in range(len(args), n):
-        r.append(defaults_tail[i - n + len(defaults_tail)])
-    return r
-
-
 ###############################
 # Base Layout Tensor Subclass #
 ###############################
@@ -327,7 +304,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                     func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
                 )
             elif dim == 1:
-                return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
+                return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self.layout_type)
             else:
                 raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
         elif func is aten.t.default:

tutorials/developer_api_guide/tensor_parallel.py

Lines changed: 2 additions & 1 deletion

@@ -5,7 +5,8 @@
 from torch.distributed import DeviceMesh
 from torch.distributed.tensor import DTensor, Replicate, Shard, Placement
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
+from my_dtype_tensor_subclass import MyDTypeTensor
+from torchao.utils import fill_defaults
 
 # a tensor subclass that supports tensor parallelism with DTensor
 class MyDTypeTensorTP(MyDTypeTensor):
