Description

Environment:
- TensorFlow 2.2
- Keras 2.3.0-tf
- TensorFlow Addons (tfa) 0.10.0
- Python 3.6
Code used:
```python
metrics = [
tfa.metrics.RSquare(name='RSquare', dtype=tf.float32, y_shape=(1,))
]
```
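When training for multiple epochs, the first training epoch runs normally. An error is then raised at the end of that epoch, when validation starts: `fit` calls `evaluate` on the validation data, and `evaluate` calls `reset_metrics()`, which fails.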
Error:
history = model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10, callbacks = callbacks)
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside run_distribute_coordinator already.
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
870 workers=workers,
871 use_multiprocessing=use_multiprocessing,
--> 872 return_dict=True)
873 val_logs = {'val_' + name: val for name, val in val_logs.items()}
874 epoch_logs.update(val_logs)
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside run_distribute_coordinator already.
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in evaluate(self, x, y, batch_size, verbose, sample_weight, steps, callbacks, max_queue_size, workers, use_multiprocessing, return_dict)
1071 callbacks.on_test_begin()
1072 for _, iterator in data_handler.enumerate_epochs(): # Single epoch.
-> 1073 self.reset_metrics()
1074 with data_handler.catch_stop_iteration():
1075 for step in data_handler.steps():
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in reset_metrics(self)
1289 """Resets the state of metrics."""
1290 for m in self.metrics:
-> 1291 m.reset_states()
1292
1293 def train_on_batch(self,
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow_addons/metrics/r_square.py in reset_states(self)
140 def reset_states(self) -> None:
141 # The state of the metric will be reset at the start of each epoch.
--> 142 self.squared_sum.assign(0)
143 self.sum.assign(0)
144 self.res.assign(0)
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py in assign(self, value, use_locking, name, read_value)
844 with _handle_graph(self.handle):
845 value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
--> 846 self._shape.assert_is_compatible_with(value_tensor.shape)
847 assign_op = gen_resource_variable_ops.assign_variable_op(
848 self.handle, value_tensor, name=name)
~/environment/anaconda2/envs/tf2/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py in assert_is_compatible_with(self, other)
1115 """
1116 if not self.is_compatible_with(other):
-> 1117 raise ValueError("Shapes %s and %s are incompatible" % (self, other))
1118
1119 def most_specific_compatible_shape(self, other):
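ValueError: Shapes (1,) and () are incompatible

The failure comes from `RSquare.reset_states()`. It assigns the scalar `0` (shape `()`) to the metric's state variables, but those variables have shape `(1,)` because the metric was created with `y_shape=(1,)`. The variable's shape-compatibility check therefore rejects the assignment.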