Call narrow only for TensorCoreTiledLayout (#1207)
* Call narrow only for TensorCoreTiledLayout

Summary:
As titled: previously, in #914 we added the narrow op for all layouts.
The introduced narrow op breaks the pattern for int8 dynamic activation int4 weight quant
in ExecuTorch, so this PR guards the narrow op to run for the tensor core tiled layout only.

If similar needs come up in the future, we can factor this into a proper API on Layout or TensorImpl (sketched below).
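
For illustration, a minimal sketch of the guard this PR adds: the helper name `dequant_post_process` and its standalone signature are hypothetical, and it assumes `TensorCoreTiledLayout` can be imported from `torchao.dtypes`; the actual change lives inside `AffineQuantizedTensor.dequantize`, see the diff to `torchao/dtypes/affine_quantized_tensor.py` below.

```python
import torch
from torchao.dtypes import TensorCoreTiledLayout  # assumed import path


def dequant_post_process(dq: torch.Tensor, layout, original_shape) -> torch.Tensor:
    # Only the tensor core tiled layout pads the weight to tile boundaries,
    # so only that layout needs to be narrowed back to the original shape.
    if isinstance(layout, TensorCoreTiledLayout):
        for dim, dim_size in enumerate(original_shape):
            dq = dq.narrow(dim, 0, dim_size)
    # Other layouts (e.g. the one used by int8 dynamic activation int4 weight
    # quant for ExecuTorch) are returned untouched, so their exported graphs
    # no longer contain aten.narrow.default.
    return dq
```

The new assertion in `test_export` checks exactly this: `torch.ops.aten.narrow.default` must not appear in the exported graph for the `_int8da_int4w_api` config.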

Test Plan:
python test/test_integration.py -k test_export

Reviewers:

Subscribers:

Tasks:

Tags:

* enable test

* version

* skip aoti

* version update

* skip aoti
jerryzh168 authored Nov 12, 2024
1 parent 2ba1a61 commit ccd883b
Showing 2 changed files with 38 additions and 25 deletions.
50 changes: 30 additions & 20 deletions test/integration/test_integration.py
@@ -24,6 +24,7 @@
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int4_weight,
quantize_,
_replace_with_custom_fn_if_matches_filter,
)
@@ -137,6 +138,12 @@ def _int4wo_api(mod):
else:
change_linear_weights_to_int4_woqtensors(mod)

def _int8da_int4w_api(mod):
quantize_(mod, int8_dynamic_activation_int4_weight(), set_inductor_config=False)
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(mod)


# TODO: use this to reduce the number of tests
TENSOR_SUBCLASS_APIS = [
_int8wo_api,
@@ -781,7 +788,7 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
@@ -973,11 +980,11 @@ def test_weight_only_groupwise_embedding_quant(self):
group_size = 64
m = nn.Embedding(4096, 128)
input = torch.randint(0, 4096, (1, 6))

quantize_(m, int8_weight_only(group_size=group_size), filter_fn=lambda x, *args: isinstance(x, nn.Embedding))
y_q = m(input)
y_ref = m.weight.dequantize()[input]

sqnr = compute_error(y_ref, y_q)

self.assertGreater(sqnr, 45.0)
@@ -1486,22 +1493,22 @@ def forward(self, x):



@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@unittest.skip("AOTI tests are failing right now")
class TestAOTI(unittest.TestCase):
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
@run_supported_device_dtype
def test_aoti(self, api, test_device, test_dtype):
if not TORCH_VERSION_AT_LEAST_2_4:
self.skipTest("aoti compatibility requires 2.4+.")

print(f"TestAOTI: {api}, {test_device}, {test_dtype}")
logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")
if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda":
self.skipTest(f"{api} in {test_device} is not support for aoti compilation yet")

if test_dtype != torch.bfloat16:
self.skipTest(f"{api} in {test_dtype} is not support for aoti compilation yet")
if test_device == "cuda" and torch.cuda.is_available() and test_dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
self.skipTest("Need CUDA and SM80+ available.")


logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")

m, k, n = 32, 64, 32

@@ -1525,29 +1532,30 @@ def forward(self, x):
ref_f = model(x)

api(model)
unwrap_tensor_subclass(model)

# running model
model(x)

# make sure it compiles
torch._inductor.config.mixed_mm_choice = "triton"

example_inputs = (x,)
torch._export.aot_compile(model, example_inputs)
torch._inductor.aoti_compile_and_package(torch.export.export(model, example_inputs), example_inputs)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestExport(unittest.TestCase):
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
list(itertools.product(TENSOR_SUBCLASS_APIS + [_int8da_int4w_api], COMMON_DEVICES, COMMON_DTYPES)),
)
@run_supported_device_dtype
def test_export(self, api, test_device, test_dtype):
if not TORCH_VERSION_AT_LEAST_2_4:
self.skipTest("aoti compatibility requires 2.4+.")
if test_device == "cuda" and torch.cuda.is_available() and test_dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
self.skipTest("Need CUDA and SM80+ available.")

logger.info(f"TestExport: {api}, {test_device}, {test_dtype}")

if test_dtype != torch.bfloat16:
self.skipTest(f"{api} in {test_dtype} is not support for aoti compilation yet")

m, k, n = 32, 64, 32

class test_model(nn.Module):
@@ -1570,6 +1578,7 @@ def forward(self, x):
ref_f = model(x)

api(model)
unwrap_tensor_subclass(model)

# running model
ref = model(x)
@@ -1585,10 +1594,11 @@ def forward(self, x):
model = torch._export.capture_pre_autograd_graph(model, example_inputs)
after_export = model(x)
self.assertTrue(torch.equal(after_export, ref))
if api is _int8da_int8w_api:
if api is _int8da_int4w_api:
targets = [n.target for n in model.graph.nodes]
self.assertTrue(torch.ops.quant.choose_qparams_affine.default in targets)
self.assertTrue(torch.ops.quant.quantize_affine.default in targets)
self.assertFalse(torch.ops.aten.narrow.default in targets)



13 changes: 8 additions & 5 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -238,10 +238,13 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
self.zero_point_domain,
output_dtype=output_dtype,
)
# need to return to original shape if tensor was padded
# in preprocessing
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
# TODO: we could add an API for this if there are more use cases
# (e.g. dequant_post_process) in TensorImpl or Layout
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
return dq

@staticmethod
@@ -1698,7 +1701,7 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
y = y + bias
return y


