@@ -54,26 +54,26 @@ def forward(
         assert A.ndim == 2, "A must be 2D"
         assert B_t.ndim == 3, "B must be 3D"
 
-        assert A.size(-1) % 16 == 0, (
-            f"A must have a last dim divisible by 16, but got shape: {A.shape}"
-        )
-        assert B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0, (
-            f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}"
-        )
+        assert (
+            A.size(-1) % 16 == 0
+        ), f"A must have a last dim divisible by 16, but got shape: {A.shape}"
+        assert (
+            B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0
+        ), f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}"
 
         # Assert input tensors are in high-precision dtypes.
-        assert A.dtype == torch.float32 or A.dtype == torch.bfloat16, (
-            "A must be float32 or bfloat16"
-        )
-        assert B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16, (
-            "B must be float32 or bfloat16"
-        )
+        assert (
+            A.dtype == torch.float32 or A.dtype == torch.bfloat16
+        ), "A must be float32 or bfloat16"
+        assert (
+            B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16
+        ), "B must be float32 or bfloat16"
         assert offs.dtype == torch.int32, "offs must be int32"
 
         # Assert A and B dims are compatible for a scaled grouped GEMM.
-        assert A.size(-1) == B_t.size(-2), (
-            f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm"
-        )
+        assert A.size(-1) == B_t.size(
+            -2
+        ), f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm"
 
         # The left operand in the scaled grouped GEMM must be row-major due to hardware requirements.
         assert not _is_column_major(A), "A must be row-major"
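For context, a minimal sketch (not part of this diff) of inputs that would pass the assertions above, assuming `A` is the 2D row-major left operand, `B_t` the 3D stack of right operands, and `offs` the int32 group-offset tensor; the concrete sizes are illustrative only:

```python
import torch

# Hypothetical inputs satisfying the checks in forward() above (illustrative only).
A = torch.randn(128, 256, dtype=torch.bfloat16)             # 2D, last dim 256 % 16 == 0
B_t = torch.randn(4, 256, 512, dtype=torch.bfloat16)        # 3D, last two dims divisible by 16
offs = torch.tensor([32, 64, 96, 128], dtype=torch.int32)   # int32 group offsets (assumed meaning)

# A.size(-1) == B_t.size(-2) == 256, so the shapes are compatible for the grouped GEMM,
# and a freshly created A is contiguous (row-major), as the last assertion requires.
```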