Skip to content

Commit 0e874fb

Browse files
committed
fix bug with float8 + inference_mode

Summary: As titled — just needed to add `reshape` to the list of supported ops.

Test Plan:
```
pytest test/float8/test_base.py -s -x -k inference_mode
```

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 9eef1ae commit 0e874fb

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

test/float8/test_base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,14 @@ def test_repr(self):
390390
s = m.__repr__()
391391
assert "i:dyn,w:del,go:dyn" in s
392392

393+
@unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
394+
def test_inference_mode(self):
395+
x = torch.randn(32, 32, device='cuda')
396+
m = nn.Sequential(nn.Linear(32, 32)).cuda()
397+
m = convert_to_float8_training(m)
398+
with torch.inference_mode(mode=True):
399+
y = m(x)
400+
393401

394402
class TestScaledMM:
395403
@unittest.skipIf(
@@ -718,8 +726,6 @@ def test_fp8_tensor_statistics(self):
718726
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)
719727
self.assertEqual((zero_cnt, max_cnt), (tensor_len, tensor_len))
720728

721-
# ghstack test 1
722-
# ghstack test 2
723729

724730
if __name__ == "__main__":
725731
pytest.main([__file__])

torchao/float8/float8_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def decorator(func):
4141
aten.slice.Tensor,
4242
aten.transpose.int,
4343
aten.fill_.Scalar,
44+
aten.reshape.default,
4445
]
4546
)
4647
def float8_desugar_op(aten_op, args, kwargs=None):

0 commit comments

Comments (0)