Skip to content

Commit 32005c9

Browse files
committed
Update
[ghstack-poisoned]
1 parent 483cdfd commit 32005c9

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torchao/prototype/mx_formats/custom_cast.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,7 +1295,7 @@ def to_mxfp8_dim1_kernel(
12951295
# TODO(future): mask this store
12961296
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
12971297

1298-
def to_mxfp8_dim1(x, inner_block_size=32):
1298+
def to_mxfp8_dim1(x, inner_block_size=32) -> Tuple[torch.Tensor, torch.Tensor]:
12991299
"""
13001300
Input:
13011301
* `x` - input tensor, in row major memory layout
@@ -1373,10 +1373,10 @@ def to_mxfp8_dim1_reference(
13731373

13741374
else:
13751375

1376-
def to_mxfp8_across_dim0_and_dim1(x, tile_size=32):
1376+
def to_mxfp8_dim1(x, inner_block_size=32) -> Tuple[torch.Tensor, torch.Tensor]:
13771377
raise AssertionError("needs torch version 2.8+ and triton")
13781378

1379-
def scale_dim0_dim1_reference(
1379+
def to_mxfp8_dim1_reference(
13801380
x_hp: torch.Tensor, block_size
1381-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1381+
) -> Tuple[torch.Tensor, torch.Tensor]:
13821382
raise AssertionError("needs torch version 2.8+ and triton")

0 commit comments

Comments
 (0)