Skip to content

Commit 3686bdd

Browse files
author
Jake Schmidt
authored
fix R_Square shape issue in model.evaluate (#2034)
* fix R_Square shape issue in model.evaluate * add test * remove erroneous line in test
1 parent 2e978a2 commit 3686bdd

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

tensorflow_addons/metrics/r_square.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,4 @@ def result(self) -> tf.Tensor:
139139

140140
def reset_states(self) -> None:
141141
# The state of the metric will be reset at the start of each epoch.
142-
K.batch_set_value([(v, 0) for v in self.variables])
142+
K.batch_set_value([(v, tf.zeros_like(v)) for v in self.variables])

tensorflow_addons/metrics/tests/r_square_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,13 @@ def test_r2_sklearn_comparison():
117117
def test_unrecognized_multioutput():
118118
with pytest.raises(ValueError):
119119
initialize_vars(multioutput="meadian")
120+
121+
122+
def test_keras_fit():
123+
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
124+
model.compile(loss="mse", metrics=[RSquare(y_shape=(1,))])
125+
data = tf.data.Dataset.from_tensor_slices(
126+
(tf.random.normal(shape=(100, 1)), tf.random.normal(shape=(100, 1)))
127+
)
128+
data = data.batch(10)
129+
model.fit(x=data, validation_data=data)

0 commit comments

Comments
 (0)