2 changes: 2 additions & 0 deletions .github/scripts/ci_test_xpu.sh
@@ -15,3 +15,5 @@ python3 -c "import torch; import torchao; print(f'Torch version: {torch.__versio
pip install pytest expecttest parameterized accelerate hf_transfer 'modelscope!=1.15.0'

pytest -v -s torchao/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

pytest -v -s torchao/test/quantization/
35 changes: 22 additions & 13 deletions test/quantization/test_gptq.py
@@ -18,13 +18,16 @@
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error
from torchao.utils import auto_detect_device

torch.manual_seed(0)

_DEVICE = auto_detect_device()
Contributor: auto_detect_device seems to be changing what we want to test. I think previously we only wanted to test on CUDA; can you preserve this?

Collaborator: We have refined the auto_detect_device function; CPU will not be included.
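
For reference, a minimal sketch of what the refined helper could look like. The actual torchao.utils.auto_detect_device implementation is not part of this diff, so the body below is an assumption based on the comment above, not the real code:

```python
# Hypothetical sketch of auto_detect_device; the real torchao.utils implementation
# is not shown in this PR. The point illustrated here is that CPU is deliberately
# not a fallback, so the tests guarded by torch.accelerator.is_available() only
# ever run on a CUDA or XPU device.
import torch


def auto_detect_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.xpu is present in recent PyTorch builds with Intel GPU support
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    raise RuntimeError("No supported accelerator (CUDA or XPU) found")
```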



class TestGPTQ(TestCase):
@unittest.skip("skipping until we get checkpoints for gpt-fast")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Contributor: just change this to torch.accelerator.is_available()?

Collaborator: Done.

@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_gptq_quantizer_int4_weight_only(self):
from torchao._models._eval import (
LMEvalInputRecorder,
@@ -33,7 +36,6 @@ def test_gptq_quantizer_int4_weight_only(self):
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer

precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path(
"../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"
)
@@ -80,19 +82,19 @@ def test_gptq_quantizer_int4_weight_only(self):
)
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)

model = quantizer.quantize(model, *inputs).cuda()
model = quantizer.quantize(model, *inputs).to(_DEVICE)

model.reset_caches()
with torch.device("cuda"):
with torch.device(_DEVICE):
model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size)

limit = 1
result = TransformerEvalWrapper(
model.cuda(),
model.to(_DEVICE),
tokenizer,
model.config.block_size,
prepare_inputs_for_model,
device,
_DEVICE,
).run_eval(
["wikitext"],
limit,
@@ -104,7 +106,7 @@ def test_gptq_quantizer_int4_weight_only(self):


class TestMultiTensorFlow(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Contributor: we don't want to expand the test to CPU, I think.

Collaborator: Done.

@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_multitensor_add_tensors(self):
from torchao.quantization.GPTQ import MultiTensor

@@ -116,7 +118,7 @@ def test_multitensor_add_tensors(self):
self.assertTrue(torch.equal(mt.values[0], tensor1))
self.assertTrue(torch.equal(mt.values[1], tensor2))

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_multitensor_pad_unpad(self):
from torchao.quantization.GPTQ import MultiTensor

@@ -127,7 +129,7 @@ def test_multitensor_pad_unpad(self):
mt.unpad()
self.assertEqual(mt.count, 1)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_multitensor_inplace_operation(self):
from torchao.quantization.GPTQ import MultiTensor

@@ -138,7 +140,7 @@ def test_multitensor_inplace_operation(self):


class TestMultiTensorInputRecorder(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_multitensor_input_recorder(self):
from torchao.quantization.GPTQ import MultiTensor, MultiTensorInputRecorder

@@ -159,7 +161,7 @@ def test_multitensor_input_recorder(self):
self.assertTrue(isinstance(MT_input[2][2], MultiTensor))
self.assertEqual(MT_input[3], torch.float)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_gptq_with_input_recorder(self):
from torchao.quantization.GPTQ import (
Int4WeightOnlyGPTQQuantizer,
@@ -170,7 +172,7 @@ def test_gptq_with_input_recorder(self):

config = ModelArgs(n_layer=2)

with torch.device("cuda"):
with torch.device(_DEVICE):
model = Transformer(config)
model.setup_caches(max_batch_size=2, max_seq_length=100)
idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32)
@@ -191,7 +193,14 @@ def test_gptq_with_input_recorder(self):

args = input_recorder.get_recorded_inputs()

quantizer = Int4WeightOnlyGPTQQuantizer()
if _DEVICE.type == "xpu":
from torchao.dtypes import Int4XPULayout

quantizer = Int4WeightOnlyGPTQQuantizer(
device=torch.device("xpu"), layout=Int4XPULayout()
)
else:
quantizer = Int4WeightOnlyGPTQQuantizer()

quantizer.quantize(model, *args)

38 changes: 17 additions & 21 deletions test/quantization/test_moe_quant.py
@@ -33,7 +33,13 @@
quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import is_sm_at_least_90
from torchao.testing.utils import skip_if_no_cuda
from torchao.utils import (
auto_detect_device,
is_sm_at_least_90,
)

_DEVICE = auto_detect_device()

if torch.version.hip is not None:
pytest.skip(
@@ -54,7 +60,7 @@ def _test_impl_moe_quant(
base_class=AffineQuantizedTensor,
tensor_impl_class=None,
dtype=torch.bfloat16,
device="cuda",
device=_DEVICE,
fullgraph=False,
):
"""
@@ -115,10 +121,8 @@ def _test_impl_moe_quant(
("multiple_tokens", 8, False),
]
)
@skip_if_no_cuda()
def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
if not torch.cuda.is_available():
self.skipTest("Need CUDA available")

config = MoEQuantConfig(
Int4WeightOnlyConfig(version=1),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
@@ -138,6 +142,7 @@ def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
("multiple_tokens", 8, False),
]
)
@skip_if_no_cuda()
def test_int4wo_base(self, name, num_tokens, fullgraph):
if not torch.cuda.is_available():
self.skipTest("Need CUDA available")
@@ -160,10 +165,8 @@ def test_int4wo_base(self, name, num_tokens, fullgraph):
("multiple_tokens", 8, False),
]
)
@skip_if_no_cuda()
def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
if not torch.cuda.is_available():
self.skipTest("Need CUDA available")

config = MoEQuantConfig(
Int8WeightOnlyConfig(), use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE
)
@@ -182,10 +185,8 @@ def test_int8wo_fake_dim(self, name, num_tokens, fullgraph):
("multiple_tokens", 8, False),
]
)
@skip_if_no_cuda()
def test_int8wo_base(self, name, num_tokens, fullgraph):
if not torch.cuda.is_available():
self.skipTest("Need CUDA available")

config = MoEQuantConfig(Int8WeightOnlyConfig())
tensor_impl_class = PlainAQTTensorImpl

@@ -202,6 +203,7 @@ def test_int8wo_base(self, name, num_tokens, fullgraph):
("multiple_tokens", 8, False),
]
)
@skip_if_no_cuda()
def test_int8wo_base_cpu(self, name, num_tokens, fullgraph):
config = MoEQuantConfig(Int8WeightOnlyConfig())
tensor_impl_class = PlainAQTTensorImpl
@@ -219,10 +221,8 @@ def test_int8wo_base_cpu(self, name, num_tokens, fullgraph):
("multiple_tokens", 32, False),
]
)
@skip_if_no_cuda()
def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
if not torch.cuda.is_available():
self.skipTest("Need CUDA available")

config = MoEQuantConfig(
Int8DynamicActivationInt8WeightConfig(),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
@@ -242,10 +242,8 @@ def test_int8dq_fake_dim(self, name, num_tokens, fullgraph):
("multiple_tokens", 32, False),
]
)
@skip_if_no_cuda()
def test_int8dq_base(self, name, num_tokens, fullgraph):
if not torch.cuda.is_available():
self.skipTest("Need CUDA available")

config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
base_class = LinearActivationQuantizedTensor

@@ -263,9 +261,8 @@ def test_int8dq_base(self, name, num_tokens, fullgraph):
("multiple_tokens", 8, False),
]
)
@skip_if_no_cuda()
def test_fp8wo_fake_dim(self, name, num_tokens, fullgraph):
if not torch.cuda.is_available():
self.skipTest("Need CUDA available")
if not is_sm_at_least_90():
self.skipTest("Requires CUDA capability >= 9.0")

@@ -335,9 +332,8 @@ def test_fp8dq_fake_dim(self, name, num_tokens, fullgraph):
("multiple_tokens", 8, False),
]
)
@skip_if_no_cuda()
def test_fp8dq_base(self, name, num_tokens, fullgraph):
if not torch.cuda.is_available():
self.skipTest("Need CUDA available")
if not is_sm_at_least_90():
self.skipTest("Requires CUDA capability >= 9.0")

14 changes: 10 additions & 4 deletions test/quantization/test_qat.py
@@ -98,12 +98,15 @@
)
from torchao.utils import (
_is_fbgemm_gpu_genai_available,
auto_detect_device,
is_fbcode,
is_sm_at_least_89,
)

# TODO: put this in a common test utils file
_CUDA_IS_AVAILABLE = torch.cuda.is_available()
_GPU_IS_AVAILABLE = torch.accelerator.is_available()
_DEVICE = auto_detect_device()


class Sub(torch.nn.Module):
@@ -347,7 +350,7 @@ def _set_ptq_weight(
group_size,
)
q_weight = torch.ops.aten._convert_weight_to_int4pack(
q_weight.to("cuda"),
q_weight.to(_DEVICE),
qat_linear.inner_k_tiles,
)
ptq_linear.weight = q_weight
@@ -600,13 +603,15 @@ def _assert_close_4w(self, val, ref):
print(mean_err)
self.assertTrue(mean_err < 0.05)

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(
not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available"
)
def test_qat_4w_primitives(self):
n_bit = 4
group_size = 32
inner_k_tiles = 8
scales_precision = torch.bfloat16
device = torch.device("cuda")
device = torch.device(_DEVICE)
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
x = torch.randn(100, 256, dtype=dtype, device=device)
@@ -699,11 +704,12 @@ def test_qat_4w_quantizer(self):

group_size = 32
inner_k_tiles = 8
device = torch.device("cuda")
device = torch.device(_DEVICE)
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
m = M().to(device).to(dtype)
m2 = copy.deepcopy(m)

qat_quantizer = Int4WeightOnlyQATQuantizer(
groupsize=group_size,
inner_k_tiles=inner_k_tiles,