Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 39 additions & 30 deletions keras/src/backend/tensorflow/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,23 +431,9 @@ def map_coordinates(
f" Received input with shape: {coordinate_arrs.shape}"
)

# unstack into a list of tensors for following operations
coordinate_arrs = tf.unstack(coordinate_arrs, axis=0)
fill_value = convert_to_tensor(tf.cast(fill_value, input_arr.dtype))

index_fixer = _INDEX_FIXERS.get(fill_mode)
if index_fixer is None:
raise ValueError(
"Invalid value for argument `fill_mode`. Expected one of "
f"{set(_INDEX_FIXERS.keys())}. Received: "
f"fill_mode={fill_mode}"
)
fill_value = convert_to_tensor(fill_value, dtype=input_arr.dtype)

def is_valid(index, size):
if fill_mode == "constant":
return (0 <= index) & (index < size)
else:
return True
coordinate_arrs = tf.unstack(coordinate_arrs, axis=0)

if order == 0:
interp_fun = _nearest_indices_and_weights
Expand All @@ -456,38 +442,61 @@ def is_valid(index, size):
else:
raise NotImplementedError("map_coordinates currently requires order<=1")

def process_coordinates(coords, size):
if fill_mode == "constant":
valid = (coords >= 0) & (coords < size)
safe_coords = tf.clip_by_value(coords, 0, size - 1)
return safe_coords, valid
elif fill_mode == "nearest":
return tf.clip_by_value(coords, 0, size - 1), tf.ones_like(
coords, dtype=tf.bool
)
elif fill_mode in ["mirror", "reflect"]:
coords = tf.abs(coords)
size_2 = size * 2
mod = tf.math.mod(coords, size_2)
under = mod < size
over = ~under
# reflect mode is same as mirror for under
coords = tf.where(under, mod, size_2 - mod)
# for reflect mode, adjust the over case
if fill_mode == "reflect":
coords = tf.where(over, coords - 1, coords)
return coords, tf.ones_like(coords, dtype=tf.bool)
elif fill_mode == "wrap":
coords = tf.math.mod(coords, size)
return coords, tf.ones_like(coords, dtype=tf.bool)
else:
raise ValueError(f"Unknown fill_mode: {fill_mode}")

valid_1d_interpolations = []
for coordinate, size in zip(coordinate_arrs, input_arr.shape):
interp_nodes = interp_fun(coordinate)
valid_interp = []
for index, weight in interp_nodes:
fixed_index = index_fixer(index, size)
valid = is_valid(index, size)
valid_interp.append((fixed_index, valid, weight))
safe_index, valid = process_coordinates(index, size)
valid_interp.append((safe_index, valid, weight))
valid_1d_interpolations.append(valid_interp)

outputs = []
for items in itertools.product(*valid_1d_interpolations):
indices, validities, weights = zip(*items)
indices = tf.transpose(tf.stack(indices))

def fast_path():
return tf.transpose(tf.gather_nd(input_arr, indices))
gathered = tf.transpose(tf.gather_nd(input_arr, indices))

def slow_path():
all_valid = functools.reduce(operator.and_, validities)
return tf.where(
all_valid,
tf.transpose(tf.gather_nd(input_arr, indices)),
fill_value,
)
if fill_mode == "constant":
all_valid = tf.reduce_all(validities)
gathered = tf.where(all_valid, gathered, fill_value)

contribution = tf.cond(tf.reduce_all(validities), fast_path, slow_path)
contribution = gathered
outputs.append(
functools.reduce(operator.mul, weights)
* tf.cast(contribution, weights[0].dtype)
)

result = functools.reduce(operator.add, outputs)

if input_arr.dtype.is_integer:
result = result if result.dtype.is_integer else tf.round(result)
result = tf.round(result)
return tf.cast(result, input_arr.dtype)
52 changes: 52 additions & 0 deletions keras/src/ops/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,58 @@ def test_map_coordinates(self):
out = kimage.map_coordinates(input, coordinates, 0)
self.assertEqual(out.shape, coordinates.shape[1:])

def test_map_coordinates_uint8(self):
image_uint8 = tf.ones((1, 1, 3), dtype=tf.uint8)
coordinates = tf.convert_to_tensor([-1.0, 0.0, 0.0])[..., None, None]

if backend.backend() != "tensorflow":
pytest.skip("Skipping test because the backend is not TensorFlow.")

out = kimage.map_coordinates(
image_uint8, coordinates, order=1, fill_mode="constant"
)
assert out.shape == coordinates.shape[1:]

def test_map_coordinates_float32(self):
image_float32 = tf.ones((1, 1, 3), dtype=tf.float32)
coordinates = tf.convert_to_tensor([-1.0, 0.0, 0.0])[..., None, None]

if backend.backend() != "tensorflow":
pytest.skip("Skipping test because the backend is not TensorFlow.")

out = kimage.map_coordinates(
image_float32, coordinates, order=1, fill_mode="constant"
)
assert out.shape == coordinates.shape[1:]

def test_map_coordinates_nearest(self):
image_uint8 = tf.ones((1, 1, 3), dtype=tf.uint8)
coordinates = tf.convert_to_tensor([-1.0, 0.0, 0.0])[..., None, None]

if backend.backend() != "tensorflow":
pytest.skip("Skipping test because the backend is not TensorFlow.")

out = kimage.map_coordinates(
image_uint8, coordinates, order=1, fill_mode="nearest"
)
assert out.shape == coordinates.shape[1:]

def test_map_coordinates_manual_cast(self):
image_uint8 = tf.ones((1, 1, 3), dtype=tf.uint8)
coordinates = tf.convert_to_tensor([-1.0, 0.0, 0.0])[..., None, None]
image_uint8_casted = tf.cast(image_uint8, dtype=tf.float32)

if backend.backend() != "tensorflow":
pytest.skip("Skipping test because the backend is not TensorFlow.")

out = tf.cast(
kimage.map_coordinates(
image_uint8_casted, coordinates, order=1, fill_mode="constant"
),
dtype=tf.uint8,
)
assert out.shape == coordinates.shape[1:]

def test_pad_images(self):
# Test channels_last
x = KerasTensor([15, 25, 3])
Expand Down
Loading