File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
torchao/prototype/mx_formats Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -1295,7 +1295,7 @@ def to_mxfp8_dim1_kernel(
1295
1295
# TODO(future): mask this store
1296
1296
tl .store (col_scale_start_ptr + col_scale_indices , col_scale_e8m0 )
1297
1297
1298
- def to_mxfp8_dim1 (x , inner_block_size = 32 ):
1298
+ def to_mxfp8_dim1 (x , inner_block_size = 32 ) -> Tuple [ torch . Tensor , torch . Tensor ] :
1299
1299
"""
1300
1300
Input:
1301
1301
* `x` - input tensor, in row major memory layout
@@ -1373,10 +1373,10 @@ def to_mxfp8_dim1_reference(
1373
1373
1374
1374
else :
1375
1375
1376
- def to_mxfp8_across_dim0_and_dim1 (x , tile_size = 32 ):
1376
+ def to_mxfp8_dim1 (x , inner_block_size = 32 ) -> Tuple [ torch . Tensor , torch . Tensor ] :
1377
1377
raise AssertionError ("needs torch version 2.8+ and triton" )
1378
1378
1379
- def scale_dim0_dim1_reference (
1379
+ def to_mxfp8_dim1_reference (
1380
1380
x_hp : torch .Tensor , block_size
1381
- ) -> Tuple [torch .Tensor , torch .Tensor , torch . Tensor , torch . Tensor ]:
1381
+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
1382
1382
raise AssertionError ("needs torch version 2.8+ and triton" )
You can’t perform that action at this time.
0 commit comments