Skip to content

Commit 0eeba14

Browse files
Leave data_format as None in keras.ops.images and add data_format support to pad_images and crop_images (keras-team#19774)
* Fix `data_format` * Add `data_format` to `pad_images` and `crop_images` * Fix CI
1 parent d59ccca commit 0eeba14

File tree

6 files changed

+409
-219
lines changed

6 files changed

+409
-219
lines changed

keras/src/backend/jax/image.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import jax
44
import jax.numpy as jnp
55

6+
from keras.src import backend
67
from keras.src.backend.jax.core import convert_to_tensor
78

89
RESIZE_INTERPOLATIONS = (
@@ -14,7 +15,8 @@
1415
)
1516

1617

17-
def rgb_to_grayscale(image, data_format="channels_last"):
18+
def rgb_to_grayscale(image, data_format=None):
19+
data_format = backend.standardize_data_format(data_format)
1820
if data_format == "channels_first":
1921
if len(image.shape) == 4:
2022
image = jnp.transpose(image, (0, 2, 3, 1))
@@ -46,8 +48,9 @@ def resize(
4648
pad_to_aspect_ratio=False,
4749
fill_mode="constant",
4850
fill_value=0.0,
49-
data_format="channels_last",
51+
data_format=None,
5052
):
53+
data_format = backend.standardize_data_format(data_format)
5154
if interpolation not in RESIZE_INTERPOLATIONS:
5255
raise ValueError(
5356
"Invalid value for argument `interpolation`. Expected of one "
@@ -297,8 +300,9 @@ def affine_transform(
297300
interpolation="bilinear",
298301
fill_mode="constant",
299302
fill_value=0,
300-
data_format="channels_last",
303+
data_format=None,
301304
):
305+
data_format = backend.standardize_data_format(data_format)
302306
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
303307
raise ValueError(
304308
"Invalid value for argument `interpolation`. Expected of one "

keras/src/backend/numpy/image.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import jax
22
import numpy as np
33

4+
from keras.src import backend
45
from keras.src.backend.numpy.core import convert_to_tensor
56
from keras.src.utils.module_utils import scipy
67

@@ -13,7 +14,8 @@
1314
)
1415

1516

16-
def rgb_to_grayscale(image, data_format="channels_last"):
17+
def rgb_to_grayscale(image, data_format=None):
18+
data_format = backend.standardize_data_format(data_format)
1719
if data_format == "channels_first":
1820
if len(image.shape) == 4:
1921
image = np.transpose(image, (0, 2, 3, 1))
@@ -45,8 +47,9 @@ def resize(
4547
pad_to_aspect_ratio=False,
4648
fill_mode="constant",
4749
fill_value=0.0,
48-
data_format="channels_last",
50+
data_format=None,
4951
):
52+
data_format = backend.standardize_data_format(data_format)
5053
if interpolation not in RESIZE_INTERPOLATIONS:
5154
raise ValueError(
5255
"Invalid value for argument `interpolation`. Expected of one "
@@ -231,8 +234,9 @@ def affine_transform(
231234
interpolation="bilinear",
232235
fill_mode="constant",
233236
fill_value=0,
234-
data_format="channels_last",
237+
data_format=None,
235238
):
239+
data_format = backend.standardize_data_format(data_format)
236240
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
237241
raise ValueError(
238242
"Invalid value for argument `interpolation`. Expected of one "

keras/src/backend/tensorflow/image.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import tensorflow as tf
66

7+
from keras.src import backend
78
from keras.src.backend.tensorflow.core import convert_to_tensor
89

910
RESIZE_INTERPOLATIONS = (
@@ -15,7 +16,8 @@
1516
)
1617

1718

18-
def rgb_to_grayscale(image, data_format="channels_last"):
19+
def rgb_to_grayscale(image, data_format=None):
20+
data_format = backend.standardize_data_format(data_format)
1921
if data_format == "channels_first":
2022
if len(image.shape) == 4:
2123
image = tf.transpose(image, (0, 2, 3, 1))
@@ -45,8 +47,9 @@ def resize(
4547
pad_to_aspect_ratio=False,
4648
fill_mode="constant",
4749
fill_value=0.0,
48-
data_format="channels_last",
50+
data_format=None,
4951
):
52+
data_format = backend.standardize_data_format(data_format)
5053
if interpolation not in RESIZE_INTERPOLATIONS:
5154
raise ValueError(
5255
"Invalid value for argument `interpolation`. Expected of one "
@@ -255,8 +258,9 @@ def affine_transform(
255258
interpolation="bilinear",
256259
fill_mode="constant",
257260
fill_value=0,
258-
data_format="channels_last",
261+
data_format=None,
259262
):
263+
data_format = backend.standardize_data_format(data_format)
260264
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
261265
raise ValueError(
262266
"Invalid value for argument `interpolation`. Expected of one "

keras/src/backend/torch/image.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch
66

7+
from keras.src import backend
78
from keras.src.backend.torch.core import convert_to_tensor
89

910
RESIZE_INTERPOLATIONS = {} # populated after torchvision import
@@ -14,7 +15,8 @@
1415
)
1516

1617

17-
def rgb_to_grayscale(image, data_format="channel_last"):
18+
def rgb_to_grayscale(image, data_format=None):
19+
data_format = backend.standardize_data_format(data_format)
1820
try:
1921
import torchvision
2022
except:
@@ -55,8 +57,9 @@ def resize(
5557
pad_to_aspect_ratio=False,
5658
fill_mode="constant",
5759
fill_value=0.0,
58-
data_format="channels_last",
60+
data_format=None,
5961
):
62+
data_format = backend.standardize_data_format(data_format)
6063
try:
6164
import torchvision
6265
from torchvision.transforms import InterpolationMode as im
@@ -216,8 +219,9 @@ def affine_transform(
216219
interpolation="bilinear",
217220
fill_mode="constant",
218221
fill_value=0,
219-
data_format="channels_last",
222+
data_format=None,
220223
):
224+
data_format = backend.standardize_data_format(data_format)
221225
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
222226
raise ValueError(
223227
"Invalid value for argument `interpolation`. Expected of one "

0 commit comments

Comments
 (0)