Skip to content

Commit fc6c393

Browse files
y-sqweifengpy
authored andcommitted
Remove two if statements in fp8 padding (pytorch#935)
Reviewed By: vkuzo Differential Revision: D63051205 Pull Request resolved: pytorch#935 Approved by: https://github.com/vkuzo
1 parent 3a9fdb0 commit fc6c393

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

torchao/float8/float8_utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,7 @@ def _get_min_alignment(size: int, alignment_value: int) -> int:
198198
16
199199
```
200200
"""
201-
if size % alignment_value == 0:
202-
return size
203-
return (1 + (size // alignment_value)) * alignment_value
201+
return (1 + ((size - 1) // alignment_value)) * alignment_value
204202

205203

206204
def pad_tensor_for_matmul(
@@ -236,10 +234,6 @@ def pad_tensor_for_matmul(
236234
dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1
237235
dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2
238236

239-
# Check if padding is needed for either dimension
240-
if dim1 == dim1_aligned and dim2 == dim2_aligned:
241-
return tensor
242-
243237
# Calculate padding values for both dimensions
244238
pad_dim1 = dim1_aligned - dim1
245239
pad_dim2 = dim2_aligned - dim2

0 commit comments

Comments
 (0)