Skip to content

Commit 84c5d77

Browse files
malfetamathewc
authored and committed
[MPSInductor] torch.complex128 is unsupported on MPS (pytorch#150386)
Same as torch.float64 Pull Request resolved: pytorch#150386 Approved by: https://github.com/dcci ghstack dependencies: pytorch#150382
1 parent 2195892 commit 84c5d77

File tree

3 files changed

+4
-1
lines changed

3 files changed

+4
-1
lines changed

test/inductor/test_mps_basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def fn(a):
162162
# Copy tests
163163
for test_name in [
164164
"test_min_max_reduction",
165+
"test_add_complex4",
165166
"test_add_const_int",
166167
"test_add_inplace_permuted",
167168
"test_addmm",

test/inductor/test_torchinductor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,6 +1368,8 @@ def fn(a, b):
13681368
return c + d
13691369

13701370
for dtype in [torch.complex32, torch.complex64, torch.complex128]:
1371+
if not self.is_dtype_supported(dtype):
1372+
continue
13711373
x = torch.tensor(
13721374
[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1],
13731375
dtype=dtype,

torch/_dynamo/device_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def is_bf16_supported(including_emulation: bool = False) -> bool:
376376
def is_dtype_supported(
377377
cls, dtype: torch.dtype, including_emulation: bool = False
378378
) -> bool:
379-
if dtype == torch.float64:
379+
if dtype in [torch.float64, torch.complex128]:
380380
return False
381381
return dtype != torch.bfloat16 or cls.is_bf16_supported(including_emulation)
382382

0 commit comments

Comments
 (0)