diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index 3d05058ff52cc..e6631d57b29ea 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -450,6 +450,10 @@ def _conv2d_legalize(attrs, inputs, arg_types): elif data_dtype in ["float16"]: if data_layout == "NHWC" and kernel_layout == "HWIO": + if isinstance(data_tensor.shape[0], tvm.tir.expr.Any): + # Skip legalize when the batch size is dynamic + return None + batch = data_tensor.shape[0].value in_channel = data_tensor.shape[3].value out_channel = kernel_tensor.shape[3].value