Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 12, 2021
1 parent a447b57 commit 588c5ab
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/python/contrib/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_dense_bias(M, N, K, out_dtype="float16"):


def get_dense_bias_relu(M, N, K, out_dtype="float16"):
return relay.nn.relu(get_dense_bias(M, N, K, out_dtype="float16"))
return relay.nn.relu(get_dense_bias(M, N, K, out_dtype=out_dtype))


def get_dense_bias_gelu(M, N, K, out_dtype="float16"):
Expand All @@ -110,17 +110,17 @@ def get_batch_matmul(batch, M, N, K, out_dtype="float16"):
return get_batch_matmul_with_shape((batch, M, K), (batch, N, K), out_dtype="float16")


def get_conv2d_nchw(d_shape, w_shape, out_dtype="float16"):
def get_conv2d_nchw(d_shape, w_shape, padding, out_dtype="float16"):
data = relay.var("data", shape=d_shape, dtype="float16")
weight = relay.var("weight", shape=w_shape, dtype="float16")
out_channel = w_shape[0]
return tvm.IRModule.from_expr(
relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=(3, 3),
kernel_size=w_shape[2:],
channels=out_channel,
padding=(1, 1),
padding=padding,
out_dtype=out_dtype,
)
)
Expand Down

0 comments on commit 588c5ab

Please sign in to comment.