Skip to content

Commit e2c7249

Browse files
add exception
1 parent c2bc6cf commit e2c7249

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

keras/src/backend/tensorflow/nn.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,14 +310,21 @@ def conv(
310310
):
311311
def _conv():
312312
tf_data_format = _convert_data_format(data_format, len(inputs.shape))
313-
return tf.nn.convolution(
313+
result = tf.nn.convolution(
314314
inputs,
315315
kernel,
316316
strides,
317317
padding.upper(),
318318
data_format=tf_data_format,
319319
dilations=dilation_rate,
320320
)
321+
if any(dim == 0 for dim in result.shape):
322+
raise ValueError(
323+
f"Convolution produced an output with size 0 dimension: "
324+
f"{result.shape}. Kernel size "
325+
f"cannot be greater than the padded input size."
326+
)
327+
return result
321328

322329
# Certain ops are are broken in Tensorflow on CPU only.
323330
# We can work around by compiling the op with XLA.

keras/src/layers/convolutional/conv_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,3 +1095,11 @@ def test_conv_constraints(self):
10951095
)
10961096
layer.build((None, 5, 5, 3))
10971097
self.assertIsInstance(layer.bias.constraint, constraints.NonNeg)
1098+
1099+
def test_conv_raises_exception_on_zero_dims(self):
1100+
x = np.random.rand(3, 4, 4, 4)
1101+
l = layers.Conv2D(6, [5, 5], 1, "valid")
1102+
with self.assertRaisesRegex(
1103+
ValueError, "Convolution produced an output with size 0 dimension"
1104+
):
1105+
l(x)

0 commit comments

Comments
 (0)