
Revert "Fix dtype bug in image converter (#2147)" #2180


Merged · 2 commits · Apr 3, 2025
42 changes: 26 additions & 16 deletions keras_hub/src/layers/preprocessing/image_converter.py
@@ -16,6 +16,7 @@
 from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
 from keras_hub.src.utils.tensor_utils import check_bounding_box_support
+from keras_hub.src.utils.tensor_utils import in_tf_function
 from keras_hub.src.utils.tensor_utils import preprocessing_function


@@ -270,36 +271,45 @@ def call(self, inputs):
         else:
             x = inputs
         if self.scale is not None:
-            x = x * self._expand_non_channel_dims(self.scale, x)
+            # If we are scaling, always cast to the compute dtype. We can't
+            # leave things as an int type if we are scaling to [0, 1].
+            scale = self._expand_non_channel_dims(self.scale, x)
+            x, scale = self._convert_types(x, scale, self.compute_dtype)
+            x = x * scale
         if self.offset is not None:
-            x = x + self._expand_non_channel_dims(self.offset, x)
+            offset = self._expand_non_channel_dims(self.offset, x)
+            x, offset = self._convert_types(x, offset, x.dtype)
+            x = x + offset
         if isinstance(inputs, dict):
             inputs["images"] = x
         else:
             inputs = x
         return inputs

     def _expand_non_channel_dims(self, value, inputs):
-        input_dtype = keras.backend.standardize_dtype(inputs.dtype)
-
         """Expand non channel dims so value is broadcastable with inputs."""
         unbatched = len(ops.shape(inputs)) == 3
         channels_first = self.data_format == "channels_first"
         if unbatched:
             broadcast_dims = (1, 2) if channels_first else (0, 1)
         else:
             broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2)
-        # If inputs are not a tensor type, return a numpy array.
-        # This might happen when running under tf.data.
-        if ops.is_tensor(inputs):
-            # preprocessing decorator moves tensors to cpu in torch backend and
-            # processed on CPU, and then converted back to the appropriate
-            # device (potentially GPU) after preprocessing.
-            if keras.backend.backend() == "torch" and self.image_size is None:
-                return ops.expand_dims(value, broadcast_dims).cpu()
-            expanded = ops.expand_dims(value, broadcast_dims)
-            return ops.cast(expanded, input_dtype)
-        else:
-            return np.expand_dims(value, broadcast_dims).astype(input_dtype)
+        # A numpy value will work with backend-native ops or with tf.data.
+        return np.expand_dims(value, broadcast_dims)
+
+    def _convert_types(self, x, y, dtype):
+        """Make sure x and y have the same dtype and are on the same device."""
+        if in_tf_function():
+            # This could happen on any backend if we are running in tf.data.
+            import tensorflow as tf
+
+            return tf.cast(x, dtype), tf.cast(y, dtype)
+        x = ops.cast(x, dtype)
+        y = ops.cast(y, dtype)
+        if keras.backend.backend() == "torch":
+            # Place on the same device as x (the image).
+            y = y.to(x.device)
+        return x, y

     def get_config(self):
         config = super().get_config()
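The core of the change is a cast-before-scale pattern: integer images are converted to the compute dtype before being multiplied by a fractional scale. Below is a minimal standalone sketch of that pattern, not keras-hub code; demo_rescale is a hypothetical name and the backend remarks are my reading of the diff.

import numpy as np


def demo_rescale(images, scale=1.0 / 255.0, compute_dtype="float32"):
    """Hypothetical sketch of the cast-before-scale pattern above."""
    # Backends disagree on mixed int/float arithmetic (TensorFlow raises
    # on mismatched dtypes; others apply their own promotion rules), so
    # cast explicitly before multiplying by a fractional scale.
    images = np.asarray(images).astype(compute_dtype)
    scale = np.asarray(scale, dtype=compute_dtype)
    return images * scale


x = np.full((2, 2, 3), 255, dtype="uint8")
print(demo_rescale(x))  # all ones: 255 * (1 / 255)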
32 changes: 20 additions & 12 deletions keras_hub/src/layers/preprocessing/image_converter_test.py
@@ -4,6 +4,7 @@
 import keras
 import numpy as np
 import pytest
+import tensorflow as tf
 from absl.testing import parameterized
 from keras import ops
@@ -22,6 +23,12 @@ def test_resize_simple(self):
         outputs = converter(inputs)
         self.assertAllClose(outputs, ops.ones((4, 4, 3)))

+    def test_resize_dataset(self):
+        converter = ImageConverter(image_size=(4, 4), scale=1 / 255.0)
+        ds = tf.data.Dataset.from_tensor_slices(tf.zeros((8, 10, 10, 3)))
+        batch = ds.batch(2).map(converter).take(1).get_single_element()
+        self.assertAllClose(batch, tf.zeros((2, 4, 4, 3)))
+
     def test_unbatched(self):
         converter = ImageConverter(
             image_size=(4, 4),
@@ -35,20 +42,21 @@ def test_unbatched(self):
         self.assertAllClose(outputs[:, :, 1], np.ones((4, 4)) * 0.301569)
         self.assertAllClose(outputs[:, :, 2], np.ones((4, 4)) * 0.852353)

-    def test_bfloat16_input(self):
+    def test_dtypes(self):
         converter = ImageConverter(image_size=(4, 4), scale=1.0 / 255.0)
+        int_image = ops.ones((10, 10, 3), dtype="uint8") * 255
+        float_image = ops.ones((10, 10, 3), dtype="float64") * 255
+        self.assertDTypeEqual(converter(int_image), "float32")
+        self.assertDTypeEqual(converter(float_image), "float32")
+        self.assertAllClose(converter(int_image), np.ones((4, 4, 3)))
+        self.assertAllClose(converter(float_image), np.ones((4, 4, 3)))
         converter = ImageConverter(
-            image_size=(4, 4),
-            scale=(1.0 / 255.0, 0.8 / 255.0, 1.2 / 255.0),
-            offset=(0.2, -0.1, 0.25),
-            dtype="bfloat16",
+            image_size=(4, 4), scale=1.0 / 255.0, dtype="bfloat16"
         )
-        inputs = ops.ones((10, 10, 3)) * 128
-        inputs = ops.cast(inputs, "bfloat16")
-        outputs = converter(inputs)
-        self.assertEqual(ops.shape(outputs), (4, 4, 3))
-        self.assertAllClose(outputs[:, :, 0], np.ones((4, 4)) * 0.703125)
-        self.assertAllClose(outputs[:, :, 1], np.ones((4, 4)) * 0.302734)
-        self.assertAllClose(outputs[:, :, 2], np.ones((4, 4)) * 0.851562)
+        self.assertDTypeEqual(converter(int_image), "bfloat16")
+        self.assertDTypeEqual(converter(float_image), "bfloat16")
+        self.assertAllClose(converter(int_image), np.ones((4, 4, 3)))
+        self.assertAllClose(converter(float_image), np.ones((4, 4, 3)))

     @parameterized.parameters(
         (True, False),
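The new test_resize_dataset covers the graph-mode path: Dataset.map traces the converter as a tf.function, which is exactly what in_tf_function() detects. A sketch of the user-facing pattern the test protects, assuming the public keras_hub.layers.ImageConverter export:

import tensorflow as tf

import keras_hub

# Inside Dataset.map the layer runs in graph mode, so _convert_types
# takes the tf.cast branch instead of backend-native ops.
converter = keras_hub.layers.ImageConverter(image_size=(4, 4), scale=1 / 255.0)
ds = tf.data.Dataset.from_tensor_slices(tf.zeros((8, 10, 10, 3)))
ds = ds.batch(2).map(converter)
for batch in ds.take(1):
    print(batch.shape)  # (2, 4, 4, 3)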
11 changes: 8 additions & 3 deletions keras_hub/src/models/vit/vit_image_converter.py
@@ -53,12 +53,17 @@ def __init__(

     @preprocessing_function
     def call(self, inputs):
+        # TODO: Remove this whole function. We can just use scale and offset
+        # in the base class.
         x = super().call(inputs)
         # By default normalize using imagenet mean and std
         if self.norm_mean:
-            x = x - self._expand_non_channel_dims(self.norm_mean, x)
+            norm_mean = self._expand_non_channel_dims(self.norm_mean, x)
+            x, norm_mean = self._convert_types(x, norm_mean, self.compute_dtype)
+            x = x - norm_mean
         if self.norm_std:
-            x = x / self._expand_non_channel_dims(self.norm_std, x)
+            norm_std = self._expand_non_channel_dims(self.norm_std, x)
+            x, norm_std = self._convert_types(x, norm_std, x.dtype)
+            x = x / norm_std

         return x
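On the TODO above: the mean/std normalization can be folded into the base class's scale and offset, since (x * s + o - mean) / std == x * (s / std) + (o - mean) / std. A quick numpy check of that algebra, using the usual ImageNet constants for illustration:

import numpy as np

s, o = 1.0 / 255.0, 0.0  # base-class scale and offset
mean = np.array([0.485, 0.456, 0.406])  # ImageNet mean
std = np.array([0.229, 0.224, 0.225])  # ImageNet std

# Fold the normalization into a single affine transform.
scale = s / std
offset = (o - mean) / std

x = np.random.uniform(0, 255, size=(4, 4, 3))
direct = (x * s + o - mean) / std
folded = x * scale + offset
print(np.allclose(direct, folded))  # True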
6 changes: 6 additions & 0 deletions keras_hub/src/utils/tensor_utils.py
@@ -28,6 +28,12 @@ def no_convert_scope():
         NO_CONVERT_COUNTER.count = getattr(NO_CONVERT_COUNTER, "count", 0) - 1


+def in_tf_function():
+    if tf is None:
+        return False
+    return not tf.executing_eagerly()
+
+
 def in_no_convert_scope():
     return getattr(NO_CONVERT_COUNTER, "count", 0) > 0
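in_tf_function() relies on tf.executing_eagerly() flipping to False while TensorFlow traces a function, which is also how tf.data executes map callables. A minimal demonstration, assuming TensorFlow is installed:

import tensorflow as tf

print(tf.executing_eagerly())  # True at the top level


@tf.function
def traced():
    # Runs at trace time: eager execution is off inside the trace.
    print("eager inside tf.function:", tf.executing_eagerly())  # False
    return tf.constant(0)


traced()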