Skip to content

Commit e9241dd

Browse files
authored
Fix dtype bug in image converter (#2147)
1 parent 6b76c07 commit e9241dd

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

keras_hub/src/layers/preprocessing/image_converter.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ def call(self, inputs):
280280
return inputs
281281

282282
def _expand_non_channel_dims(self, value, inputs):
283+
input_dtype = keras.backend.standardize_dtype(inputs.dtype)
284+
283285
unbatched = len(ops.shape(inputs)) == 3
284286
channels_first = self.data_format == "channels_first"
285287
if unbatched:
@@ -294,9 +296,10 @@ def _expand_non_channel_dims(self, value, inputs):
294296
# device (potentially GPU) after preprocessing.
295297
if keras.backend.backend() == "torch" and self.image_size is None:
296298
return ops.expand_dims(value, broadcast_dims).cpu()
297-
return ops.expand_dims(value, broadcast_dims)
299+
expanded = ops.expand_dims(value, broadcast_dims)
300+
return ops.cast(expanded, input_dtype)
298301
else:
299-
return np.expand_dims(value, broadcast_dims)
302+
return np.expand_dims(value, broadcast_dims).astype(input_dtype)
300303

301304
def get_config(self):
302305
config = super().get_config()

keras_hub/src/layers/preprocessing/image_converter_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ def test_unbatched(self):
3535
self.assertAllClose(outputs[:, :, 1], np.ones((4, 4)) * 0.301569)
3636
self.assertAllClose(outputs[:, :, 2], np.ones((4, 4)) * 0.852353)
3737

38+
def test_bfloat16_input(self):
39+
converter = ImageConverter(
40+
image_size=(4, 4),
41+
scale=(1.0 / 255.0, 0.8 / 255.0, 1.2 / 255.0),
42+
offset=(0.2, -0.1, 0.25),
43+
dtype="bfloat16",
44+
)
45+
inputs = ops.ones((10, 10, 3)) * 128
46+
inputs = ops.cast(inputs, "bfloat16")
47+
outputs = converter(inputs)
48+
self.assertEqual(ops.shape(outputs), (4, 4, 3))
49+
self.assertAllClose(outputs[:, :, 0], np.ones((4, 4)) * 0.703125)
50+
self.assertAllClose(outputs[:, :, 1], np.ones((4, 4)) * 0.302734)
51+
self.assertAllClose(outputs[:, :, 2], np.ones((4, 4)) * 0.851562)
52+
3853
@parameterized.parameters(
3954
(True, False),
4055
(False, True),

0 commit comments

Comments
 (0)