
Commit 1ba3b8f

Fix discretization discrepancy (#21769)
* fix discretization discrepancy
* fix resolution of output type and update test
* fix dtype comparison for torch
* remove duplicate assert
1 parent 53987a7 commit 1ba3b8f

2 files changed: +32 -5 lines changed


keras/src/layers/preprocessing/discretization.py

Lines changed: 6 additions & 5 deletions

@@ -95,9 +95,6 @@ def __init__(
         dtype=None,
         name=None,
     ):
-        if dtype is None:
-            dtype = "int64" if output_mode == "int" else backend.floatx()
-
         super().__init__(name=name, dtype=dtype)

         if sparse and not backend.SUPPORTS_SPARSE_TENSORS:
@@ -155,6 +152,10 @@ def __init__(
     def input_dtype(self):
         return backend.floatx()

+    @property
+    def output_dtype(self):
+        return self.compute_dtype if self.output_mode != "int" else "int32"
+
     def adapt(self, data, steps=None):
         """Computes bin boundaries from quantiles in a input dataset.

@@ -213,7 +214,7 @@ def reset_state(self):
         self.summary = np.array([[], []], dtype="float32")

     def compute_output_spec(self, inputs):
-        return backend.KerasTensor(shape=inputs.shape, dtype=self.compute_dtype)
+        return backend.KerasTensor(shape=inputs.shape, dtype=self.output_dtype)

     def load_own_variables(self, store):
         if len(store) == 1:
@@ -234,7 +235,7 @@ def call(self, inputs):
             indices,
             output_mode=self.output_mode,
             depth=len(self.bin_boundaries) + 1,
-            dtype=self.compute_dtype,
+            dtype=self.output_dtype,
             sparse=self.sparse,
             backend_module=self.backend,
         )
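In short, the output dtype now comes from a single `output_dtype` property, which both the eager `call` path and the symbolic `compute_output_spec` path consult, so the two can no longer disagree. A minimal sketch of the behavior after this change (not part of the commit; the exact dtype string reported by each backend may vary slightly):

import numpy as np
from keras import layers

# With output_mode="int", the layer's output dtype now resolves via the new
# `output_dtype` property ("int32"), for both eager and symbolic execution.
layer = layers.Discretization(
    bin_boundaries=[-0.5, 0.0, 0.1, 0.2, 3.0],
    output_mode="int",
)

x = np.array([[0.0, 0.15, 0.21, 0.3]])
eager_out = layer(x)  # eager path: `call` passes dtype=self.output_dtype
print(eager_out.dtype)  # expected: an integer dtype

symbolic_in = layers.Input(shape=(4,), dtype="float32")
symbolic_out = layer(symbolic_in)  # symbolic path: `compute_output_spec`
print(symbolic_out.dtype)  # expected: "int32", matching the eager output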

keras/src/layers/preprocessing/discretization_test.py

Lines changed: 26 additions & 0 deletions

@@ -205,3 +205,29 @@ def test_call_before_adapt_raises(self):
         layer = layers.Discretization(num_bins=3)
         with self.assertRaisesRegex(ValueError, "You need .* call .*adapt"):
             layer([[0.1, 0.8, 0.9]])
+
+    def test_model_call_vs_predict_consistency(self):
+        """Test that model(input) and model.predict(input) produce consistent outputs."""  # noqa: E501
+        # Test with int output mode
+        layer = layers.Discretization(
+            bin_boundaries=[-0.5, 0, 0.1, 0.2, 3],
+            output_mode="int",
+        )
+        x = np.array([[0.0, 0.15, 0.21, 0.3], [0.0, 0.17, 0.451, 7.8]])
+
+        # Create model
+        inputs = layers.Input(shape=(4,), dtype="float32")
+        outputs = layer(inputs)
+        model = models.Model(inputs=inputs, outputs=outputs)
+
+        # Test both execution modes
+        model_call_output = model(x)
+        predict_output = model.predict(x)
+
+        # Check consistency
+        self.assertAllClose(model_call_output, predict_output)
+        self.assertEqual(
+            backend.standardize_dtype(model_call_output.dtype),
+            backend.standardize_dtype(predict_output.dtype),
+        )
+        self.assertTrue(backend.is_int_dtype(model_call_output.dtype))
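The same consistency check can be reproduced outside the test harness. A hedged standalone sketch follows; the `keras.src` import layout mirrors what the test module above appears to use, and `backend.standardize_dtype` is the helper referenced in the test:

import numpy as np
from keras.src import backend, layers, models

layer = layers.Discretization(
    bin_boundaries=[-0.5, 0, 0.1, 0.2, 3],
    output_mode="int",
)
x = np.array([[0.0, 0.15, 0.21, 0.3], [0.0, 0.17, 0.451, 7.8]])

inputs = layers.Input(shape=(4,), dtype="float32")
model = models.Model(inputs=inputs, outputs=layer(inputs))

eager_out = model(x)            # eager __call__
predict_out = model.predict(x)  # compiled predict loop

# Values and dtypes should agree across both execution paths.
np.testing.assert_allclose(np.asarray(eager_out), predict_out)
assert backend.standardize_dtype(eager_out.dtype) == backend.standardize_dtype(
    predict_out.dtype
)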

0 commit comments
