Skip to content

Commit 12989d6

Browse files
committed
black
1 parent 5eb487a commit 12989d6

File tree

4 files changed

+10
-15
lines changed

4 files changed

+10
-15
lines changed

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,12 @@ def enumerate_gemm_operators(
8989
C = TensorDescription(element_c, LayoutType.RowMajor, alignment)
9090

9191
if element_c == DataType.s32 and A.alignment == 1:
92-
tile_description.threadblock_shape[0] = min(tile_description.threadblock_shape[0], 128)
93-
tile_description.threadblock_shape[1] = min(tile_description.threadblock_shape[1], 128)
92+
tile_description.threadblock_shape[0] = min(
93+
tile_description.threadblock_shape[0], 128
94+
)
95+
tile_description.threadblock_shape[1] = min(
96+
tile_description.threadblock_shape[1], 128
97+
)
9498

9599
op = GemmOperation(
96100
tile_description.minimum_compute_capability,

python/tvm/relay/op/strategy/cuda.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -872,18 +872,8 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
872872
or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
873873
)
874874
)
875-
or (
876-
data.dtype in ["int4", "uint4"]
877-
and i % 32 == 0
878-
and b % 8 == 0
879-
and o % 8 == 0
880-
)
881-
or (
882-
data.dtype in ["int1", "uint1"]
883-
and i % 128 == 0
884-
and b % 8 == 0
885-
and o % 8 == 0
886-
)
875+
or (data.dtype in ["int4", "uint4"] and i % 32 == 0 and b % 8 == 0 and o % 8 == 0)
876+
or (data.dtype in ["int1", "uint1"] and i % 128 == 0 and b % 8 == 0 and o % 8 == 0)
887877
):
888878
strategy.add_implementation(
889879
wrap_compute_dense(topi.cuda.dense_tensorcore),

tests/python/contrib/test_cublas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def test_batch_matmul():
169169

170170
verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "int8", "int32")
171171

172+
172173
if __name__ == "__main__":
173174
test_matmul_add()
174175
test_batch_matmul()

tests/python/contrib/test_cutlass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ def test_conv2d():
742742
rtol=1e-5,
743743
ref_target="llvm",
744744
data_dtype="uint8",
745-
weight_dtype="int8"
745+
weight_dtype="int8",
746746
)
747747

748748

0 commit comments

Comments
 (0)