Skip to content

tfa.metrics.RSquare: ValueError: Shapes (1,) and () are incompatible #2520

Open
@FeiCoding

Description

@FeiCoding

Environment:
- TensorFlow 2.2
- Keras 2.3.0-tf
- TensorFlow Addons (tfa) 0.10.0
- Python 3.6

Code used:

metrics = [
tfa.metrics.RSquare(name='RSquare', dtype=tf.float32, y_shape=(1,))
]

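When training for multiple epochs, the first epoch trains normally. However, an error is raised when validation begins at the end of that epoch. At that point Keras calls `reset_metrics()`, and `RSquare.reset_states()` tries to assign the scalar `0` to state variables created with shape `y_shape=(1,)`.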

Error traceback:

history = model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10, callbacks = callbacks)

~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside run_distribute_coordinator already.

~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
870 workers=workers,
871 use_multiprocessing=use_multiprocessing,
--> 872 return_dict=True)
873 val_logs = {'val_' + name: val for name, val in val_logs.items()}
874 epoch_logs.update(val_logs)

~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside run_distribute_coordinator already.

~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in evaluate(self, x, y, batch_size, verbose, sample_weight, steps, callbacks, max_queue_size, workers, use_multiprocessing, return_dict)
1071 callbacks.on_test_begin()
1072 for _, iterator in data_handler.enumerate_epochs(): # Single epoch.
-> 1073 self.reset_metrics()
1074 with data_handler.catch_stop_iteration():
1075 for step in data_handler.steps():

~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in reset_metrics(self)
1289 """Resets the state of metrics."""
1290 for m in self.metrics:
-> 1291 m.reset_states()
1292
1293 def train_on_batch(self,

~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow_addons/metrics/r_square.py in reset_states(self)
140 def reset_states(self) -> None:
141 # The state of the metric will be reset at the start of each epoch.
--> 142 self.squared_sum.assign(0)
143 self.sum.assign(0)
144 self.res.assign(0)

~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py in assign(self, value, use_locking, name, read_value)
844 with _handle_graph(self.handle):
845 value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
--> 846 self._shape.assert_is_compatible_with(value_tensor.shape)
847 assign_op = gen_resource_variable_ops.assign_variable_op(
848 self.handle, value_tensor, name=name)

~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py in assert_is_compatible_with(self, other)
1115 """
1116 if not self.is_compatible_with(other):
-> 1117 raise ValueError("Shapes %s and %s are incompatible" % (self, other))
1118
1119 def most_specific_compatible_shape(self, other):

ValueError: Shapes (1,) and () are incompatible

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions