
Commit e38ccee

lint
1 parent: ddf9d82

2 files changed: +22 lines, -18 lines


torchao/prototype/scaled_grouped_mm/kernels/jagged_float8_scales.py

Lines changed: 7 additions & 3 deletions
@@ -5,9 +5,11 @@
 # LICENSE file in the root directory of this source tree.

 """
-Triton kernels for scaling high precision tensors to float8.
+Triton kernels for scaling high precision tensors to float8 using "jagged"
+rowwise scales (i.e., separate scales for each group/subtensor as determined by
+the offsets).
 """
-import itertools
+
 from typing import Tuple

 import torch
@@ -33,7 +35,9 @@

 block_sizes = [128, 256]
 kernel_configs_2D = [
-    triton.Config({"BLOCK_SIZE_ROWS": block_size_rows, "BLOCK_SIZE_COLS": block_size_cols})
+    triton.Config(
+        {"BLOCK_SIZE_ROWS": block_size_rows, "BLOCK_SIZE_COLS": block_size_cols}
+    )
     for block_size_rows in block_sizes
     for block_size_cols in block_sizes
 ]
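
Note: the new docstring describes "jagged" rowwise scales, i.e. a separate set of rowwise float8 scales per group/subtensor, with groups delimited by offsets. Below is a minimal pure-PyTorch sketch of that idea, not the Triton kernel in this file; the function name, the groups-along-the-last-dim layout, and the e4m3 target are illustrative assumptions.

import torch

def jagged_rowwise_scales_reference(x: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    # Illustrative reference only (assumed layout, not the kernel's API).
    # x: (K, total_M) high-precision tensor whose columns are split into groups by
    # `offs`, the cumulative end offset of each group (e.g. [128, 320, 512], int32).
    # Returns float32 scales of shape (num_groups, K): one scale per row, per group.
    e4m3_max = torch.finfo(torch.float8_e4m3fn).max
    scales = torch.empty(offs.numel(), x.shape[0], dtype=torch.float32, device=x.device)
    start = 0
    for g, end in enumerate(offs.tolist()):
        # Rowwise amax computed over this group's columns only.
        amax = x[:, start:end].abs().amax(dim=1).clamp(min=1e-12)
        scales[g] = e4m3_max / amax.to(torch.float32)
        start = end
    return scales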

torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py

Lines changed: 15 additions & 15 deletions
@@ -54,26 +54,26 @@ def forward(
         assert A.ndim == 2, "A must be 2D"
         assert B_t.ndim == 3, "B must be 3D"

-        assert A.size(-1) % 16 == 0, (
-            f"A must have a last dim divisible by 16, but got shape: {A.shape}"
-        )
-        assert B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0, (
-            f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}"
-        )
+        assert (
+            A.size(-1) % 16 == 0
+        ), f"A must have a last dim divisible by 16, but got shape: {A.shape}"
+        assert (
+            B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0
+        ), f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}"

         # Assert input tensors are in high-precision dtypes.
-        assert A.dtype == torch.float32 or A.dtype == torch.bfloat16, (
-            "A must be float32 or bfloat16"
-        )
-        assert B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16, (
-            "B must be float32 or bfloat16"
-        )
+        assert (
+            A.dtype == torch.float32 or A.dtype == torch.bfloat16
+        ), "A must be float32 or bfloat16"
+        assert (
+            B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16
+        ), "B must be float32 or bfloat16"
         assert offs.dtype == torch.int32, "offs must be int32"

         # Assert A and B dims are compatible for a scaled grouped GEMM.
-        assert A.size(-1) == B_t.size(-2), (
-            f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm"
-        )
+        assert A.size(-1) == B_t.size(
+            -2
+        ), f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm"

         # The left operand in the scaled grouped GEMM must be row-major due to hardware requirements.
         assert not _is_column_major(A), "A must be row-major"
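
Note: the reformatted asserts above encode the input contract for the scaled grouped GEMM: A is 2D and row-major, B_t is 3D, the relevant dims are divisible by 16, offs is int32, and A.size(-1) == B_t.size(-2). Below is a hedged usage sketch of inputs satisfying that contract; the import path and the `offs` keyword are assumptions to verify against the module, not a confirmed public API.

import torch
from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import _scaled_grouped_mm  # assumed path

K, N, num_groups = 256, 512, 3
offs = torch.tensor([128, 256, 384], dtype=torch.int32, device="cuda")  # int32 cumulative group end offsets
A = torch.randn(384, K, dtype=torch.bfloat16, device="cuda")            # 2D, row-major, last dim % 16 == 0
B = torch.randn(num_groups, N, K, dtype=torch.bfloat16, device="cuda")
B_t = B.transpose(-2, -1)  # 3D; last two dims % 16 == 0 and A.size(-1) == B_t.size(-2)
out = _scaled_grouped_mm(A, B_t, offs=offs)  # assumed call signature, per the asserts above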
