Skip to content

Commit c22de4b

Browse files
Fix the Bug in func preprocess_input when x in 3D and data_format=='channels_first' (keras-team#21750)
* Fix the Bug in func `preprocess_input` when `x` in 3D and `data_format=='channels_first'` * Update keras/src/applications/imagenet_utils.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent e45fedb commit c22de4b

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

keras/src/applications/imagenet_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,10 @@ def _preprocess_tensor_input(x, data_format, mode):
278278

279279
# Zero-center by mean pixel
280280
if data_format == "channels_first":
281-
mean_tensor = ops.reshape(mean_tensor, (1, 3) + (1,) * (ndim - 2))
281+
if len(x.shape) == 3:
282+
mean_tensor = ops.reshape(mean_tensor, (3, 1, 1))
283+
else:
284+
mean_tensor = ops.reshape(mean_tensor, (1, 3) + (1,) * (ndim - 2))
282285
else:
283286
mean_tensor = ops.reshape(mean_tensor, (1,) * (ndim - 1) + (3,))
284287
x += mean_tensor

0 commit comments

Comments
 (0)