Skip to content

Commit 5439f42

Browse files
y-sqfacebook-github-bot
authored andcommitted
Remove two if statements in fp8 padding
Reviewed By: vkuzo Differential Revision: D63051205
1 parent 26e790d commit 5439f42

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

torchao/float8/float8_utils.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,7 @@ def _get_min_alignment(size: int, alignment_value: int) -> int:
196196
16
197197
```
198198
"""
199-
if size % alignment_value == 0:
200-
return size
201-
return (1 + (size // alignment_value)) * alignment_value
199+
return (1 + ((size-1) // alignment_value)) * alignment_value
202200

203201

204202
def pad_tensor_for_matmul(
@@ -234,10 +232,6 @@ def pad_tensor_for_matmul(
234232
dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1
235233
dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2
236234

237-
# Check if padding is needed for either dimension
238-
if dim1 == dim1_aligned and dim2 == dim2_aligned:
239-
return tensor
240-
241235
# Calculate padding values for both dimensions
242236
pad_dim1 = dim1_aligned - dim1
243237
pad_dim2 = dim2_aligned - dim2

0 commit comments

Comments
 (0)