Skip to content

Commit fda836f

Browse files
committed
add missing F8_E4M3 and F8_E5M2 keys from str_to_torch_dtype
1 parent 3f76848 commit fda836f

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

src/transformers/modeling_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,10 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
 535  535      str_to_torch_dtype["U32"] = torch.uint32
 536  536      str_to_torch_dtype["U64"] = torch.uint64
 537  537
      538  +   if is_torch_greater_or_equal("2.1.0"):
      539  +       str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
      540  +       str_to_torch_dtype["F8_E5M2"] = torch.float8_e5m2
      541  +
 538  542
 539  543  def load_state_dict(
 540  544      checkpoint_file: Union[str, os.PathLike],

0 commit comments

Comments
 (0)