This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 61791d8

make to more restrictive
1 parent: 2789085

File tree: 1 file changed (+5, -2 lines)


float8_experimental/float8_ops.py

Lines changed: 5 additions & 2 deletions
@@ -71,12 +71,15 @@ def float8_is_same_size(aten_op, args, kwargs=None):
 
 @implements([aten._to_copy.default])
 def autocast_to_copy(aten_op, args, kwargs=None):
-    # TODO Think about the scale propagation with autocast
+    """ This gets called when running matmul under autocast
+    when the input is a Float8Tensor, presenting as a fp32
+    tensor.
+    """
     assert isinstance(args[0], Float8Tensor)
     assert len(kwargs) == 1 and "dtype" in kwargs, "Only support dtype kwarg for autocast"
     assert kwargs[
         "dtype"
-    ].is_floating_point, "Only support floating point conversion for autocast w/ Float8Tensor"
+    ] == torch.float16, "Only support floating point conversion for autocast w/ Float8Tensor"
     return Float8Tensor(
         args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._buffer_refs, args[0]._emulate
     )
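
To make the effect of the tightened assertion concrete, here is a minimal sketch (an illustration, not part of the commit) of the autocast path that reaches `aten._to_copy.default`: under `torch.autocast("cuda")`, whose default cast dtype on CUDA is `torch.float16`, matmul inputs that present as fp32 are cast to fp16 via `_to_copy`, which is the case `autocast_to_copy` intercepts for `Float8Tensor`. The `check_autocast_target` helper below is hypothetical and only mirrors the new check; plain fp32 tensors stand in for the fp32-presenting `Float8Tensor`.

```python
import torch

def check_autocast_target(dtype: torch.dtype) -> None:
    # Hypothetical helper mirroring the tightened assertion in autocast_to_copy:
    # after this commit only torch.float16 is accepted, where previously any
    # floating-point dtype passed.
    assert dtype == torch.float16, (
        "Only support float16 conversion for autocast w/ Float8Tensor"
    )

check_autocast_target(torch.float16)     # passes
# check_autocast_target(torch.bfloat16)  # would now trip the assertion

if torch.cuda.is_available():
    # Plain fp32 tensors stand in for Float8Tensor's fp32-presenting view.
    a = torch.randn(16, 16, device="cuda")
    b = torch.randn(16, 16, device="cuda")
    with torch.autocast("cuda"):  # default autocast dtype on CUDA is float16
        out = a @ b               # autocast casts the inputs to fp16 via aten._to_copy
    print(out.dtype)              # torch.float16
```

In other words, a bf16 autocast region over a `Float8Tensor` matmul would now fail the assertion rather than being silently relabeled.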

0 commit comments
