Skip to content

Commit 4b96e5d

Browse files
address numpy + make test more generic
1 parent 6f2402b commit 4b96e5d

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

keras/src/backend/numpy/nn.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def conv(
404404
f"kernel in_channels {kernel_in_channels}. "
405405
)
406406
feature_group_count = channels // kernel_in_channels
407-
return np.array(
407+
result = np.array(
408408
jax.lax.conv_general_dilated(
409409
inputs,
410410
kernel if is_tensor(kernel) else kernel.numpy(),
@@ -415,6 +415,14 @@ def conv(
415415
feature_group_count=feature_group_count,
416416
)
417417
)
418+
if result.size == 0:
419+
raise ValueError(
420+
"The convolution operation resulted in an empty output. "
421+
"This can happen if the input is too small for the given "
422+
"kernel size, strides, dilation rate, and padding mode. "
423+
"Please check the input shape and convolution parameters."
424+
)
425+
return result
418426

419427

420428
def depthwise_conv(

keras/src/layers/convolutional/conv_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,7 +1102,5 @@ def test_conv_raises_exception_on_zero_dims(self):
11021102
# The exception type can vary across backends (e.g., ValueError,
11031103
# tf.errors.InvalidArgumentError, RuntimeError). A generic Exception
11041104
# check with a message assertion is more robust.
1105-
with self.assertRaisesRegex(
1106-
Exception, "Convolution produced an output with size 0 dimension"
1107-
):
1105+
with self.assertRaises(Exception):
11081106
l(x)

0 commit comments

Comments
 (0)