Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing cuda device check #536

Merged
merged 1 commit into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,17 +642,19 @@ def test_int8wo_quantized_model_to_device(self):
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
def test_int4wo_quantized_model_to_device(self):
# TODO: change initial model to "cpu"
m = ToyLinearModel().eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")

quantize_(m, int4_weight_only())
ref = m(*example_inputs)

example_inputs_cuda = (example_inputs[0].to("cuda"),)
m.to(device="cuda")
cuda_res = m(*example_inputs_cuda)
self.assertEqual(cuda_res.cpu(), ref)
devices = ["cuda", "cuda:0"]
for device in devices:
m = ToyLinearModel().eval().to(torch.bfloat16).to(device)
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device)

quantize_(m, int4_weight_only())
ref = m(*example_inputs)

example_inputs_cuda = (example_inputs[0].to(device),)
m.to(device=device)
cuda_res = m(*example_inputs_cuda)
self.assertEqual(cuda_res.cpu(), ref)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Expand Down
3 changes: 2 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_register_layout_cls,
_get_layout_tensor_constructor,
LayoutType,
is_device,
)
from typing import ClassVar
from dataclasses import dataclass
Expand Down Expand Up @@ -544,7 +545,7 @@ def from_plain(
def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
device = kwargs["device"]
if device != "cuda" and (isinstance(device, torch.device) and device.type != "cuda"):
if not is_device("cuda", device):
raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}")
return self.__class__(
self.packed_weight.to(device),
Expand Down
5 changes: 4 additions & 1 deletion torchao/dtypes/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from typing import Dict, Callable
from typing import Dict, Callable, Union
from collections import defaultdict
import functools
from dataclasses import dataclass
Expand Down Expand Up @@ -89,3 +89,6 @@ def _get_layout_tensor_constructor(cls: Callable, layout_type_class: type(Layout
raise ValueError(f"layout_name: {layout_type_class} is not supported yet for {cls}")

return _LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class]

def is_device(target_device_str: str, device: Union[str, torch.device]):
return torch.device(device).type == target_device_str
Loading